In [1]:
import os
from os.path import join
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
import torch.nn as nn
import torchvision
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
from astra.torch.utils import train_fn
from astra.torch.metrics import accuracy_score, precision_score, recall_score, f1_score
import matplotlib.pyplot as plt

# Use the GPU exposed through CUDA_VISIBLE_DEVICES when one is available;
# otherwise fall back to the CPU so every later `.to(device)` call still works.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [2]:
# Read the four saved `ban_*` tensors from the data directory, in the order
# train images, train labels, test images, test labels.
path = "/home/vannsh.jani/brick_kilns/ssl_exp/data"
ban_files = ("ban_x_train.pt", "ban_y_train.pt", "ban_x_test.pt", "ban_y_test.pt")
x_train, y_train, x_test, y_test = (torch.load(join(path, fname)) for fname in ban_files)

print(*(t.shape for t in (x_train, y_train, x_test, y_test)))

# Stack the train and test portions along the first axis into one set and
# move the result to `device`.
x_b = torch.cat((x_train, x_test)).to(device)
y_b = torch.cat((y_train, y_test)).to(device)

print(x_b.shape, y_b.shape)

torch.Size([19124, 3, 224, 224]) torch.Size([19124]) torch.Size([6375, 3, 224, 224]) torch.Size([6375])
torch.Size([25499, 3, 224, 224]) torch.Size([25499])


In [3]:
# Load the Delhi test images and their labels and place both on `device`.
delhi_images_file = join(path, "delhi_test_images_50.pt")
delhi_labels_file = join(path, "delhi_test_labels_50.pt")
x_d = torch.load(delhi_images_file).to(device)
y_d = torch.load(delhi_labels_file).to(device)
print(x_d.shape, y_d.shape)

torch.Size([5013, 3, 224, 224]) torch.Size([5013])


In [4]:
# Downstream model
class DownstreamModel(nn.Module):
    """EfficientNet-B0 classifier whose backbone starts from a saved checkpoint.

    The torchvision EfficientNet-B0 is split into two parts:

    * ``self.eff`` -- every child module except the last one, wrapped in an
      ``nn.Sequential``. The checkpoint is loaded into this part, so its keys
      must match that ``nn.Sequential``.
    * ``self.last_layer`` -- the last child module (the classifier block). Its
      module at index 1 is replaced by a freshly initialised linear layer
      with ``num_classes`` outputs.

    Args:
        num_classes: Number of output logits.
        weights_path: Path to the backbone state dict. Defaults to the SimSiam
            checkpoint used in this notebook. Pass ``None`` to skip loading
            and keep the random initialisation.
    """

    def __init__(
        self,
        num_classes,
        weights_path="/home/vannsh.jani/brick_kilns/ssl_exp/simsiam/eff_simsiam_bd_150.pth",
    ):
        super().__init__()
        eff = torchvision.models.efficientnet_b0(weights=None)
        children = list(eff.children())
        self.eff = nn.Sequential(*children[:-1])
        self.last_layer = children[-1]

        if weights_path is not None:
            # map_location="cpu" lets a checkpoint saved on any device be
            # restored here; load_state_dict copies the values into the
            # existing parameters, and the caller moves the model afterwards.
            # NOTE(review): torch.load unpickles the file -- only load
            # checkpoints from a trusted source (consider weights_only=True).
            state_dict = torch.load(weights_path, map_location="cpu")
            self.eff.load_state_dict(state_dict)

        # Size the new head from the layer it replaces instead of hard-coding
        # the feature width.
        in_features = self.last_layer[1].in_features
        self.last_layer[1] = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for images ``x``."""
        # Collapse everything after the batch axis so the backbone output can
        # be fed to the linear head (the trailing spatial axes are expected to
        # be of size 1 after pooling).
        features = torch.flatten(self.eff(x), 1)
        return self.last_layer(features)

In [5]:
model = DownstreamModel(2).to(device)

# Smoke-test the forward pass on a dummy batch. Run it in eval mode and without
# autograd so the random input neither builds a graph nor updates any
# normalisation running statistics, then restore train mode so the model is
# left in the same mode as right after construction.
model.eval()
with torch.no_grad():
    dummy_logits = model(torch.rand(2, 3, 224, 224).to(device))
model.train()
dummy_logits.shape

torch.Size([2, 2])

In [6]:
lr = 1e-4
epochs = 10
batch_size = 64  # shared by training and both evaluation passes
loss_fn = nn.CrossEntropyLoss()


def evaluate_in_batches(model, x, y, loss_fn, batch_size):
    """Evaluate ``model`` on ``(x, y)`` in mini-batches with gradients disabled.

    Args:
        model: Module mapping a batch of inputs to class logits.
        x: Input tensor, indexed along its first axis.
        y: Label tensor aligned with ``x``; cast to ``long`` for the loss.
        loss_fn: Loss returning the mean over the batch it is given.
        batch_size: Number of samples per forward pass.

    Returns:
        ``(mean_loss, predictions)``: the per-sample mean loss over the whole
        set as a float, and the argmax class index for every row of ``x``.

    Raises:
        ValueError: If ``x`` is empty.
    """
    if len(x) == 0:
        raise ValueError("cannot evaluate on an empty dataset")

    total_loss = 0.0
    predictions = []
    with torch.no_grad():
        for start in range(0, len(x), batch_size):
            xb = x[start:start + batch_size]
            yb = y[start:start + batch_size].long()
            logits = model(xb)
            # loss_fn averages within the batch; re-weight by the batch size
            # so a short final batch is not over-counted in the overall mean.
            total_loss += loss_fn(logits, yb).item() * len(xb)
            predictions.append(torch.argmax(logits, dim=1))
    return total_loss / len(x), torch.cat(predictions)


train_losses = []
test_losses = []
for epoch in range(epochs):
    print("Epoch:", epoch)
    model.train()
    # NOTE(review): train_fn is called afresh for every epoch; if it creates
    # its optimizer internally, the optimizer state is reset each epoch --
    # confirm against the astra implementation.
    iter_losses, epoch_losses = train_fn(model, loss_fn, x_b, y_b, lr=lr, epochs=1, batch_size=batch_size)

    model.eval()

    # Loss on the training set, measured with the weights reached this epoch.
    train_loss = evaluate_in_batches(model, x_b, y_b, loss_fn, batch_size)[0]
    train_losses.append(train_loss)

    # Held-out Delhi set: evaluated in batches so the whole set never has to
    # go through the network in a single forward pass.
    test_loss, y_pred = evaluate_in_batches(model, x_d, y_d, loss_fn, batch_size)
    test_losses.append(test_loss)

    # NOTE(review): the labels are passed first and the predictions second;
    # confirm this matches the astra metric signatures, otherwise precision
    # and recall are swapped.
    print("Accuracy:", accuracy_score(y_d, y_pred))
    print("Precision:", precision_score(y_d, y_pred))
    print("Recall:", recall_score(y_d, y_pred))
    print("F1 Score:", f1_score(y_d, y_pred))
    print("Train Loss:", train_losses[-1])
    print("Test Loss:", test_losses[-1])

    torch.cuda.empty_cache()

Epoch: 0


Loss: 0.10774184: 100%|██████████| 1/1 [00:24<00:00, 24.50s/it]


Accuracy: tensor(0.9364, device='cuda:0')
Precision: tensor(0.7006, device='cuda:0')
Recall: tensor(0.6913, device='cuda:0')
F1 Score: tensor(0.6959, device='cuda:0')
Train Loss: 0.026502915033673805
Test Loss: 0.17517372965812683
Epoch: 1


Loss: 0.03058052: 100%|██████████| 1/1 [00:24<00:00, 24.47s/it]


Accuracy: tensor(0.9370, device='cuda:0')
Precision: tensor(0.7313, device='cuda:0')
Recall: tensor(0.6840, device='cuda:0')
F1 Score: tensor(0.7069, device='cuda:0')
Train Loss: 0.013407692697306484
Test Loss: 0.1879873126745224
Epoch: 2


Loss: 0.01440878: 100%|██████████| 1/1 [00:24<00:00, 24.42s/it]


Accuracy: tensor(0.9517, device='cuda:0')
Precision: tensor(0.6833, device='cuda:0')
Recall: tensor(0.8222, device='cuda:0')
F1 Score: tensor(0.7463, device='cuda:0')
Train Loss: 0.005365202769495309
Test Loss: 0.18698406219482422
Epoch: 3


Loss: 0.00726419: 100%|██████████| 1/1 [00:24<00:00, 24.45s/it]


Accuracy: tensor(0.8580, device='cuda:0')
Precision: tensor(0.7428, device='cuda:0')
Recall: tensor(0.4010, device='cuda:0')
F1 Score: tensor(0.5209, device='cuda:0')
Train Loss: 0.0034546510933410844
Test Loss: 0.4265051484107971
Epoch: 4


Loss: 0.00546201: 100%|██████████| 1/1 [00:24<00:00, 24.43s/it]


Accuracy: tensor(0.9543, device='cuda:0')
Precision: tensor(0.6891, device='cuda:0')
Recall: tensor(0.8427, device='cuda:0')
F1 Score: tensor(0.7582, device='cuda:0')
Train Loss: 0.002141347480895441
Test Loss: 0.24830669164657593
Epoch: 5


Loss: 0.00369552: 100%|██████████| 1/1 [00:24<00:00, 24.43s/it]


Accuracy: tensor(0.9074, device='cuda:0')
Precision: tensor(0.7006, device='cuda:0')
Recall: tensor(0.5423, device='cuda:0')
F1 Score: tensor(0.6114, device='cuda:0')
Train Loss: 0.00565757939775699
Test Loss: 0.4959467649459839
Epoch: 6


Loss: 0.00488550: 100%|██████████| 1/1 [00:24<00:00, 24.43s/it]


Accuracy: tensor(0.9579, device='cuda:0')
Precision: tensor(0.6526, device='cuda:0')
Recall: tensor(0.9189, device='cuda:0')
F1 Score: tensor(0.7632, device='cuda:0')
Train Loss: 0.0016260646667219053
Test Loss: 0.30913805961608887
Epoch: 7


Loss: 0.00346739: 100%|██████████| 1/1 [00:24<00:00, 24.43s/it]


Accuracy: tensor(0.9501, device='cuda:0')
Precision: tensor(0.7274, device='cuda:0')
Recall: tensor(0.7782, device='cuda:0')
F1 Score: tensor(0.7520, device='cuda:0')
Train Loss: 0.001043816561094248
Test Loss: 0.2977580428123474
Epoch: 8


Loss: 0.00261428: 100%|██████████| 1/1 [00:24<00:00, 24.42s/it]


Accuracy: tensor(0.9507, device='cuda:0')
Precision: tensor(0.6449, device='cuda:0')
Recall: tensor(0.8442, device='cuda:0')
F1 Score: tensor(0.7312, device='cuda:0')
Train Loss: 0.0016808339531327878
Test Loss: 0.4634805917739868
Epoch: 9


Loss: 0.00317017: 100%|██████████| 1/1 [00:24<00:00, 24.44s/it]


Accuracy: tensor(0.9553, device='cuda:0')
Precision: tensor(0.6449, device='cuda:0')
Recall: tensor(0.8960, device='cuda:0')
F1 Score: tensor(0.7500, device='cuda:0')
Train Loss: 0.0016162050478654027
Test Loss: 0.35397353768348694
