In [1]:
import os
import random
from PIL import Image
import argparse
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms, models

In [2]:
# Team identification printed alongside the evaluation result.
team_id, team_name, email_address = 15, "loSSLess", "vvb238@nyu.edu"

In [3]:
class CustomDataset(torch.utils.data.Dataset):
    r"""Image dataset laid out as ``<root>/<split>/<idx>.png`` with optional labels.

    Labels are read from ``<root>/<split>_label_tensor.pt`` when that file exists;
    otherwise every sample gets the placeholder label ``-1`` (long dtype).
    """

    def __init__(self, root, split, transform):
        r"""
        Args:
            root: Location of the dataset folder, usually it is /dataset
            split: The split to use; it should be one of train, val or unlabeled.
            transform: The transform to apply to each image.
        """

        self.split = split
        self.transform = transform

        self.image_dir = os.path.join(root, split)
        label_path = os.path.join(root, f"{split}_label_tensor.pt")

        # Count only the .png images: any other directory entry (hidden/system
        # files, stray artifacts) would inflate __len__ and make __getitem__
        # request a "<n>.png" that does not exist.
        self.num_images = sum(
            1 for name in os.listdir(self.image_dir) if name.endswith(".png")
        )

        if os.path.exists(label_path):
            # NOTE(review): torch.load unpickles its input; only load trusted files.
            self.labels = torch.load(label_path)
        else:
            # Placeholder label for unlabeled splits.
            self.labels = torch.full((self.num_images,), -1, dtype=torch.long)

    def __len__(self):
        return self.num_images

    def __getitem__(self, idx):
        """Return ``(transform(image), label)`` for sample ``idx``."""
        path = os.path.join(self.image_dir, f"{idx}.png")
        with open(path, 'rb') as f:
            # convert() forces the pixel data to be read while the file is still open.
            img = Image.open(f).convert('RGB')

        return self.transform(img), self.labels[idx]

In [4]:
def _tensor_and_normalize():
    """Return fresh ToTensor + ImageNet-statistics Normalize steps (tail of both pipelines)."""
    return [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]


# Training pipeline: photometric jitter and random flips, then tensor conversion + normalization.
# NOTE(review): hue=0.5 is the largest value ColorJitter accepts, i.e. hue is fully randomized.
_train_augmentations = [
    transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
]
train_transform = transforms.Compose(_train_augmentations + _tensor_and_normalize())

# Evaluation pipeline: deterministic, no augmentation.
eval_transform = transforms.Compose(_tensor_and_normalize())

In [5]:
# Labelled training split, reshuffled by the loader every epoch.
trainset = CustomDataset(
    root='../input/supervised-dataset/supervised_dataset',
    split="train",
    transform=train_transform,
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=64,
    shuffle=True,
    num_workers=2,
)

In [13]:
import torchvision

In [14]:
class SwaVRes34(nn.Module):
    """ResNet-34 encoder initialised from a saved checkpoint, followed by a linear classifier.

    Args:
        checkpoint_path: Path to a state dict for the encoder, i.e. the ResNet-34
            children up to (and including) the average pool, without ``fc``.
        num_classes: Number of logits produced by the classifier.
    """

    def __init__(self, checkpoint_path="../input/fastai-swav/encoder.pth", num_classes=800):
        super().__init__()
        # resnet34() without arguments builds a randomly initialised network in every
        # torchvision version (no deprecated `pretrained=` kwarg needed).
        backbone = models.resnet34()
        # Keep every child except the final fully connected layer (ends with avgpool).
        self.encoder = nn.Sequential(*list(backbone.children())[:-1])
        # map_location="cpu" keeps tensors on CPU so a GPU-saved checkpoint loads anywhere.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        self.encoder.load_state_dict(checkpoint)
        # Feature width comes from the backbone's own head (512 for ResNet-34).
        self.classifier = nn.Linear(in_features=backbone.fc.in_features, out_features=num_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``."""
        # Encoder ends with a pooling layer; flatten everything after the batch dim.
        rep = torch.flatten(self.encoder(x), 1)
        return self.classifier(rep)

In [15]:
model = SwaVRes34()  # builds the classifier and loads the encoder checkpoint

In [16]:
# Use the GPU when present; fall back to CPU instead of crashing on CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = model.to(device)
criterion = nn.CrossEntropyLoss()
# NOTE(review): plain SGD without momentum. The recorded run's epoch loss only fell
# from ~6.70 to ~6.15 in 50 epochs (chance level for 800 classes is ln(800) ~ 6.68);
# consider momentum and/or a different learning rate.
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
# Multiply the LR by the default factor once the monitored loss stalls for 5 epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5)

In [17]:
print('Start Training')
# Exact number of batches per epoch (counts a final partial batch, unlike
# len(trainset) / batch_size, which is a float and under-counts).
num_batches = len(trainloader)
# Send each batch to wherever the model's parameters live instead of hard-coding CUDA.
device = next(net.parameters()).device

for epoch in range(50):
    net.train()
    running_loss = 0.0
    for inputs, labels in tqdm(trainloader, total=num_batches):
        inputs, labels = inputs.to(device), labels.to(device)

        outputs = net(inputs)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    epoch_loss = running_loss / num_batches
    print("Loss at epoch", epoch, "is", epoch_loss)
    # Step on the mean loss; with ReduceLROnPlateau's default relative threshold this
    # makes the same plateau decisions as stepping on the sum.
    scheduler.step(epoch_loss)

print('Finished Training')

Start Training


100%|██████████| 400/400 [03:09<00:00,  2.11it/s]

Loss at epoch 0 is 6.697576781511307



100%|██████████| 400/400 [01:25<00:00,  4.70it/s]

Loss at epoch 1 is 6.4966594660282135



100%|██████████| 400/400 [01:26<00:00,  4.64it/s]

Loss at epoch 2 is 6.423213518857956



100%|██████████| 400/400 [01:26<00:00,  4.62it/s]

Loss at epoch 3 is 6.379365441799163



100%|██████████| 400/400 [01:26<00:00,  4.61it/s]

Loss at epoch 4 is 6.349410374164581



100%|██████████| 400/400 [01:26<00:00,  4.62it/s]

Loss at epoch 5 is 6.328703097105026



100%|██████████| 400/400 [01:27<00:00,  4.57it/s]

Loss at epoch 6 is 6.313792891502381



100%|██████████| 400/400 [01:26<00:00,  4.65it/s]

Loss at epoch 7 is 6.302347882986068



100%|██████████| 400/400 [01:26<00:00,  4.60it/s]

Loss at epoch 8 is 6.289307940006256



100%|██████████| 400/400 [01:26<00:00,  4.63it/s]

Loss at epoch 9 is 6.283633580207825



100%|██████████| 400/400 [01:27<00:00,  4.59it/s]

Loss at epoch 10 is 6.2703089427948



100%|██████████| 400/400 [01:27<00:00,  4.58it/s]

Loss at epoch 11 is 6.263007271289825



100%|██████████| 400/400 [01:26<00:00,  4.65it/s]

Loss at epoch 12 is 6.2581951820850374



100%|██████████| 400/400 [01:27<00:00,  4.60it/s]

Loss at epoch 13 is 6.254954314231872



100%|██████████| 400/400 [01:26<00:00,  4.64it/s]

Loss at epoch 14 is 6.248858327865601



100%|██████████| 400/400 [01:26<00:00,  4.60it/s]

Loss at epoch 15 is 6.245929700136185



100%|██████████| 400/400 [01:27<00:00,  4.60it/s]

Loss at epoch 16 is 6.236321294307709



100%|██████████| 400/400 [01:25<00:00,  4.67it/s]

Loss at epoch 17 is 6.232666176557541



100%|██████████| 400/400 [01:26<00:00,  4.65it/s]

Loss at epoch 18 is 6.228970054388046



100%|██████████| 400/400 [01:25<00:00,  4.70it/s]

Loss at epoch 19 is 6.2273547959327695



100%|██████████| 400/400 [01:25<00:00,  4.67it/s]

Loss at epoch 20 is 6.223515694141388



100%|██████████| 400/400 [01:23<00:00,  4.77it/s]

Loss at epoch 21 is 6.21463249206543



100%|██████████| 400/400 [01:24<00:00,  4.72it/s]

Loss at epoch 22 is 6.214496537446975



100%|██████████| 400/400 [01:24<00:00,  4.75it/s]

Loss at epoch 23 is 6.217131797075272



100%|██████████| 400/400 [01:23<00:00,  4.77it/s]

Loss at epoch 24 is 6.208888179063797



100%|██████████| 400/400 [01:23<00:00,  4.80it/s]

Loss at epoch 25 is 6.20715482711792



100%|██████████| 400/400 [01:20<00:00,  4.96it/s]

Loss at epoch 26 is 6.207216184139252



100%|██████████| 400/400 [01:22<00:00,  4.87it/s]

Loss at epoch 27 is 6.202262964248657



100%|██████████| 400/400 [01:21<00:00,  4.90it/s]

Loss at epoch 28 is 6.197941930294037



100%|██████████| 400/400 [01:21<00:00,  4.88it/s]

Loss at epoch 29 is 6.193656015396118



100%|██████████| 400/400 [01:20<00:00,  4.97it/s]

Loss at epoch 30 is 6.190692994594574



100%|██████████| 400/400 [01:22<00:00,  4.87it/s]

Loss at epoch 31 is 6.184669544696808



100%|██████████| 400/400 [01:23<00:00,  4.82it/s]

Loss at epoch 32 is 6.187751567363739



100%|██████████| 400/400 [01:21<00:00,  4.88it/s]

Loss at epoch 33 is 6.186635009050369



100%|██████████| 400/400 [01:22<00:00,  4.84it/s]

Loss at epoch 34 is 6.179054121971131



100%|██████████| 400/400 [01:20<00:00,  4.99it/s]

Loss at epoch 35 is 6.174796094894409



100%|██████████| 400/400 [01:23<00:00,  4.79it/s]

Loss at epoch 36 is 6.170912672281265



100%|██████████| 400/400 [01:25<00:00,  4.66it/s]

Loss at epoch 37 is 6.173562723398208



100%|██████████| 400/400 [01:24<00:00,  4.76it/s]

Loss at epoch 38 is 6.171831804513931



100%|██████████| 400/400 [01:24<00:00,  4.71it/s]

Loss at epoch 39 is 6.165448198318481



100%|██████████| 400/400 [01:22<00:00,  4.82it/s]

Loss at epoch 40 is 6.165814943313599



100%|██████████| 400/400 [01:25<00:00,  4.69it/s]

Loss at epoch 41 is 6.166445529460907



100%|██████████| 400/400 [01:24<00:00,  4.74it/s]

Loss at epoch 42 is 6.158789964914322



100%|██████████| 400/400 [01:23<00:00,  4.82it/s]

Loss at epoch 43 is 6.158680576086044



100%|██████████| 400/400 [01:23<00:00,  4.77it/s]

Loss at epoch 44 is 6.1561285078525545



100%|██████████| 400/400 [01:23<00:00,  4.81it/s]

Loss at epoch 45 is 6.152451269626617



100%|██████████| 400/400 [01:22<00:00,  4.83it/s]

Loss at epoch 46 is 6.1536513662338255



100%|██████████| 400/400 [01:24<00:00,  4.76it/s]

Loss at epoch 47 is 6.148523621559143



100%|██████████| 400/400 [01:23<00:00,  4.81it/s]

Loss at epoch 48 is 6.152449351549149



100%|██████████| 400/400 [01:23<00:00,  4.78it/s]

Loss at epoch 49 is 6.150812653303146
Finished Training





In [18]:
# Checkpoint directory: override with the MODEL_DIR environment variable rather than
# editing this cluster-specific absolute path (kept as the default).
save_model_location = os.environ.get("MODEL_DIR", '/scratch/vvb238/models/')
os.makedirs(save_model_location, exist_ok=True)
torch.save(net.state_dict(), os.path.join(save_model_location, "resnet34.pth"))

In [19]:
# Labelled validation split, iterated in a fixed order with larger batches (no gradients needed).
evalset = CustomDataset(
    root='../input/supervised-dataset/supervised_dataset',
    split="val",
    transform=eval_transform,
)
evalloader = torch.utils.data.DataLoader(
    evalset,
    batch_size=256,
    shuffle=False,
    num_workers=2,
)

In [None]:
# Top-1 accuracy of the trained model on the validation split.
net.eval()
# Send each batch to wherever the model's parameters live instead of hard-coding CUDA.
device = next(net.parameters()).device
correct = 0
total = 0
with torch.no_grad():
    for images, labels in evalloader:
        images = images.to(device)
        labels = labels.to(device)

        predicted = net(images).argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Fail with a clear message rather than a bare ZeroDivisionError on an empty split.
if total == 0:
    raise ValueError("evalloader produced no samples; cannot compute accuracy")

print(f"Team {team_id}: {team_name} Accuracy: {(100 * correct / total):.2f}%")