In [1]:
import os
from os.path import join
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

import torch
import torch.nn as nn
import torchvision
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
from astra.torch.utils import train_fn
from astra.torch.metrics import accuracy_score, precision_score, recall_score, f1_score
import matplotlib.pyplot as plt

device = "cuda"

In [2]:
path = "/home/vannsh.jani/brick_kilns/ssl_exp/data"
x_train = torch.load(join(path, "ban_x_train.pt"))
y_train = torch.load(join(path, "ban_y_train.pt"))
x_test = torch.load(join(path, "ban_x_test.pt"))
y_test = torch.load(join(path, "ban_y_test.pt"))

print(x_train.shape, y_train.shape, x_test.shape, y_test.shape)
x_b = torch.cat([x_train, x_test], dim=0).to(device)
y_b = torch.cat([y_train, y_test], dim=0).to(device)

print(x_b.shape, y_b.shape)

torch.Size([19124, 3, 224, 224]) torch.Size([19124]) torch.Size([6375, 3, 224, 224]) torch.Size([6375])
torch.Size([25499, 3, 224, 224]) torch.Size([25499])


In [3]:
x_d = torch.load(join(path, "delhi_test_images_50.pt")).to(device)
y_d = torch.load(join(path, "delhi_test_labels_50.pt")).to(device)
print(x_d.shape, y_d.shape)

torch.Size([5013, 3, 224, 224]) torch.Size([5013])


In [4]:
# Downstream model
class DownstreamModel(nn.Module):
    def __init__(self,num_classes):
        super().__init__()
        resnet = torchvision.models.efficientnet_b0(weights=None)
        self.resnet = nn.Sequential(*list(resnet.children())[:-1])
        self.last_layer = list(resnet.children())[-1]
        # print(dict(self.resnet.named_parameters()).keys())
        self.resnet.load_state_dict({k.replace("resnet.", ""): v for k, v in torch.load("/home/vannsh.jani/brick_kilns/ssl_exp/dino/dino_bd_75_delhi_test_train5_final_table.pth").items() if k not in ["fc.weight", "fc.bias"]}) # load different weights
        self.last_layer[1] = nn.Linear(1280, num_classes)

    def forward(self, x):
        x = self.resnet(x).squeeze(-2, -1)
        # print(x.shape)
        x = self.last_layer(x)
        return x

In [5]:
model = DownstreamModel(2).to(device)
model(torch.rand(2, 3, 224, 224).to(device)).shape

torch.Size([2, 2])

In [6]:
lr = 1e-4
epochs = 10
loss_fn = nn.CrossEntropyLoss()

train_losses = []
test_losses = []
for epoch in range(epochs):
    print("Epoch:", epoch)
    model.train()
    iter_losses, epoch_losses = train_fn(model, loss_fn, x_b, y_b, lr=lr, epochs=1, batch_size=64)
    
    model.eval()
    
    with torch.no_grad():
        local_losses = []
        for i in range(0, len(x_b), 64):
            y_pred = model(x_b[i:i+64])
            loss = loss_fn(y_pred, y_b[i:i+64].long())
            local_losses.append(loss.item())

        train_losses.append(sum(local_losses)/len(local_losses))
    
    with torch.no_grad():
        y_pred = model(x_d)
        test_loss = loss_fn(y_pred, y_d.long())
        test_losses.append(test_loss.item())
        y_pred = torch.argmax(y_pred, dim=1)
        print("Accuracy:", accuracy_score(y_d, y_pred))
        print("Precision:", precision_score(y_d, y_pred))
        print("Recall:", recall_score(y_d, y_pred))
        print("F1 Score:", f1_score(y_d, y_pred))
        print("Train Loss:", train_losses[-1])
        print("Test Loss:", test_losses[-1])
        
    torch.cuda.empty_cache()

Epoch: 0


Loss: 0.08144404: 100%|██████████| 1/1 [00:24<00:00, 24.68s/it]


Accuracy: tensor(0.9575, device='cuda:0')
Precision: tensor(0.6315, device='cuda:0')
Recall: tensor(0.9400, device='cuda:0')
F1 Score: tensor(0.7555, device='cuda:0')
Train Loss: 0.020741720507644994
Test Loss: 0.13553185760974884
Epoch: 1


Loss: 0.01816072: 100%|██████████| 1/1 [00:24<00:00, 24.51s/it]


Accuracy: tensor(0.9575, device='cuda:0')
Precision: tensor(0.6353, device='cuda:0')
Recall: tensor(0.9350, device='cuda:0')
F1 Score: tensor(0.7566, device='cuda:0')
Train Loss: 0.011514820287151164
Test Loss: 0.1623053103685379
Epoch: 2


Loss: 0.01080246: 100%|██████████| 1/1 [00:24<00:00, 24.49s/it]


Accuracy: tensor(0.9575, device='cuda:0')
Precision: tensor(0.6353, device='cuda:0')
Recall: tensor(0.9350, device='cuda:0')
F1 Score: tensor(0.7566, device='cuda:0')
Train Loss: 0.0070075427926318224
Test Loss: 0.18949022889137268
Epoch: 3


Loss: 0.00567212: 100%|██████████| 1/1 [00:24<00:00, 24.48s/it]


Accuracy: tensor(0.9565, device='cuda:0')
Precision: tensor(0.6219, device='cuda:0')
Recall: tensor(0.9391, device='cuda:0')
F1 Score: tensor(0.7483, device='cuda:0')
Train Loss: 0.00406423602208704
Test Loss: 0.2482527643442154
Epoch: 4


Loss: 0.00284343: 100%|██████████| 1/1 [00:24<00:00, 24.48s/it]


Accuracy: tensor(0.9585, device='cuda:0')
Precision: tensor(0.6372, device='cuda:0')
Recall: tensor(0.9459, device='cuda:0')
F1 Score: tensor(0.7615, device='cuda:0')
Train Loss: 0.0027631315152896953
Test Loss: 0.2856736183166504
Epoch: 5


Loss: 0.00195885: 100%|██████████| 1/1 [00:24<00:00, 24.49s/it]


Accuracy: tensor(0.9565, device='cuda:0')
Precision: tensor(0.6411, device='cuda:0')
Recall: tensor(0.9151, device='cuda:0')
F1 Score: tensor(0.7540, device='cuda:0')
Train Loss: 0.0021288321425856366
Test Loss: 0.3742600679397583
Epoch: 6


Loss: 0.00098726: 100%|██████████| 1/1 [00:24<00:00, 24.49s/it]


Accuracy: tensor(0.9591, device='cuda:0')
Precision: tensor(0.6488, device='cuda:0')
Recall: tensor(0.9389, device='cuda:0')
F1 Score: tensor(0.7673, device='cuda:0')
Train Loss: 0.0021756362240938614
Test Loss: 0.43321630358695984
Epoch: 7


Loss: 0.00125853: 100%|██████████| 1/1 [00:24<00:00, 24.48s/it]


Accuracy: tensor(0.9529, device='cuda:0')
Precision: tensor(0.5854, device='cuda:0')
Recall: tensor(0.9385, device='cuda:0')
F1 Score: tensor(0.7210, device='cuda:0')
Train Loss: 0.002864020850487401
Test Loss: 0.5094847679138184
Epoch: 8


Loss: 0.00145421: 100%|██████████| 1/1 [00:24<00:00, 24.49s/it]


Accuracy: tensor(0.9611, device='cuda:0')
Precision: tensor(0.6756, device='cuda:0')
Recall: tensor(0.9312, device='cuda:0')
F1 Score: tensor(0.7831, device='cuda:0')
Train Loss: 0.0017227483522373383
Test Loss: 0.44481340050697327
Epoch: 9


Loss: 0.00083111: 100%|██████████| 1/1 [00:24<00:00, 24.49s/it]


Accuracy: tensor(0.9579, device='cuda:0')
Precision: tensor(0.6372, device='cuda:0')
Recall: tensor(0.9379, device='cuda:0')
F1 Score: tensor(0.7589, device='cuda:0')
Train Loss: 0.00227797791405225
Test Loss: 0.5771077871322632
