In [1]:
import os
import random
from PIL import Image
import argparse
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms, models
from fastai.vision.models import xresnet

In [2]:
# Submission metadata; team_id and team_name are interpolated into the
# accuracy line printed by the evaluation cell.
team_id, team_name, email_address = 15, "loSSLess", "vvb238@nyu.edu"

In [3]:
class CustomDataset(torch.utils.data.Dataset):
    """Image dataset stored as ``<root>/<split>/<idx>.png`` with optional labels.

    Labels are loaded from ``<root>/<split>_label_tensor.pt`` when that file
    exists; otherwise every sample gets the label ``-1`` (e.g. an unlabeled split).
    """

    def __init__(self, root, split, transform=None):
        r"""
        Args:
            root: Location of the dataset folder, usually it is /dataset
            split: The split to use; it should be one of train, val or unlabeled.
            transform: Callable applied to each PIL image. ``None`` returns the
                PIL image unchanged.
        """

        self.split = split
        self.transform = transform

        self.image_dir = os.path.join(root, split)
        label_path = os.path.join(root, f"{split}_label_tensor.pt")

        # Count only ``.png`` files: __getitem__ reads ``{idx}.png``, so stray
        # entries such as ``.DS_Store`` or ``.ipynb_checkpoints`` must not inflate
        # the dataset length (that would make the last indices fail to open).
        self.num_images = sum(
            1 for name in os.listdir(self.image_dir) if name.endswith(".png")
        )

        if os.path.exists(label_path):
            self.labels = torch.load(label_path)
        else:
            # Sentinel label -1 marks "no label available" for every image.
            self.labels = torch.full((self.num_images,), -1, dtype=torch.long)

    def __len__(self):
        return self.num_images

    def __getitem__(self, idx):
        """Return ``(image, label)`` for image ``idx``; the image is RGB-converted."""
        with open(os.path.join(self.image_dir, f"{idx}.png"), 'rb') as f:
            # convert() forces the pixel data to load while the file is still open.
            img = Image.open(f).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)
        return img, self.labels[idx]

In [4]:
# Per-channel normalization statistics shared by the train and eval pipelines
# (the standard ImageNet RGB mean/std used throughout torchvision).
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

# Training-time augmentation, applied to PIL images before tensor conversion.
# NOTE(review): hue=0.5 is the largest value ColorJitter accepts and allows a
# full hue rotation, and vertical flips are unusual for natural images —
# confirm both are intended.
train_transform = transforms.Compose(
    [
        transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
    ]
)

# Deterministic evaluation pipeline: tensor conversion + normalization only.
eval_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
    ]
)

In [5]:
# Labelled training split, reshuffled every epoch.
trainset = CustomDataset(
    root='/dataset',
    split="train",
    transform=train_transform,
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=64,
    shuffle=True,
    num_workers=2,
)

In [6]:
import torchvision

In [7]:
class SwaVRes34(nn.Module):
    """Linear classifier on top of a pre-trained xresnet18 encoder.

    NOTE(review): despite "Res34" in the name, the backbone built here is
    ``xresnet18``; the class name is kept so existing callers keep working.
    """

    def __init__(self, encoder_path="./encoder_xres_v2.pth", num_classes=800):
        """
        Args:
            encoder_path: File holding the state_dict of the headless encoder
                (defaults to the checkpoint the notebook has always used).
            num_classes: Output size of the linear classification head.
        """
        super().__init__()
        # Build xresnet18 without its final child (presumably the original
        # Linear head — TODO confirm against the installed fastai version);
        # pretrained=False because the weights come from encoder_path below.
        self.encoder = nn.Sequential(
            *list(xresnet.xresnet18(pretrained=False).children())[:-1]
        )
        # map_location="cpu" loads every tensor onto the CPU regardless of the
        # device it was saved from (same effect as the identity lambda).
        checkpoint = torch.load(encoder_path, map_location="cpu")
        self.encoder.load_state_dict(checkpoint)
        # 512 = width of the encoder's output features for xresnet18
        # (assumed; must change if the backbone changes).
        self.classifier = nn.Linear(in_features=512, out_features=num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for input batch ``x``."""
        # flatten(1) keeps the batch dim and, unlike view(), tolerates
        # non-contiguous encoder outputs.
        rep = self.encoder(x).flatten(1)
        return self.classifier(rep)

In [8]:
# Instantiating the classifier also loads the encoder checkpoint from disk.
model = SwaVRes34()

In [9]:
# Module.cuda() moves parameters in place and returns the same module,
# so `net` and `model` are the same object afterwards.
net = model.cuda()
criterion = nn.CrossEntropyLoss()
# Plain SGD over every parameter (encoder and head are both fine-tuned).
optimizer = torch.optim.SGD(params=net.parameters(), lr=0.01)
# Lower the LR once the monitored quantity stops decreasing for 5 epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', patience=5
)

In [10]:
print('Start Training')
# Exact number of batches per epoch, including a trailing partial batch.
# Used both as the tqdm total and as the denominator of the mean epoch loss
# (len(trainset) / batch_size is a float that ignores the partial batch).
numOfBatches = len(trainloader)
assert numOfBatches > 0, "training loader yields no batches"

for epoch in range(50):
    net.train()
    running_loss = 0.0
    for inputs, labels in tqdm(trainloader, total=numOfBatches):
        inputs, labels = inputs.cuda(), labels.cuda()

        outputs = net(inputs)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    epoch_loss = running_loss / numOfBatches
    print("Loss at epoch", epoch, "is", epoch_loss)
    # The mean is a positive rescaling of the summed loss, so with the default
    # relative threshold ReduceLROnPlateau makes the same plateau decisions,
    # and the monitored value now matches the printed one.
    scheduler.step(epoch_loss)

print('Finished Training')

Start Training


100%|██████████| 400/400 [00:39<00:00, 10.05it/s]

Loss at epoch 0 is 6.6930344200134275



100%|██████████| 400/400 [00:26<00:00, 15.07it/s]

Loss at epoch 1 is 6.232588925361633



100%|██████████| 400/400 [00:26<00:00, 15.10it/s]

Loss at epoch 2 is 5.863012660741806



100%|██████████| 400/400 [00:26<00:00, 14.96it/s]

Loss at epoch 3 is 5.565988391637802



100%|██████████| 400/400 [00:27<00:00, 14.29it/s]

Loss at epoch 4 is 5.317149314880371



100%|██████████| 400/400 [00:28<00:00, 14.17it/s]

Loss at epoch 5 is 5.115214765071869



100%|██████████| 400/400 [00:28<00:00, 14.08it/s]

Loss at epoch 6 is 4.953958375453949



100%|██████████| 400/400 [00:27<00:00, 14.47it/s]

Loss at epoch 7 is 4.8138041305542



100%|██████████| 400/400 [00:27<00:00, 14.73it/s]

Loss at epoch 8 is 4.692055580615997



100%|██████████| 400/400 [00:28<00:00, 14.10it/s]

Loss at epoch 9 is 4.601152571439743



100%|██████████| 400/400 [00:28<00:00, 14.18it/s]

Loss at epoch 10 is 4.512380722761154



100%|██████████| 400/400 [00:27<00:00, 14.71it/s]

Loss at epoch 11 is 4.436258012652397



100%|██████████| 400/400 [00:28<00:00, 14.11it/s]

Loss at epoch 12 is 4.372679911255837



100%|██████████| 400/400 [00:27<00:00, 14.47it/s]

Loss at epoch 13 is 4.309593679904938



100%|██████████| 400/400 [00:28<00:00, 14.20it/s]

Loss at epoch 14 is 4.2651438140869145



100%|██████████| 400/400 [00:28<00:00, 14.10it/s]

Loss at epoch 15 is 4.213372098803521



100%|██████████| 400/400 [00:28<00:00, 14.10it/s]

Loss at epoch 16 is 4.163716551661492



100%|██████████| 400/400 [00:28<00:00, 14.22it/s]

Loss at epoch 17 is 4.120588554143906



100%|██████████| 400/400 [00:26<00:00, 14.97it/s]

Loss at epoch 18 is 4.095967952013016



100%|██████████| 400/400 [00:28<00:00, 14.24it/s]

Loss at epoch 19 is 4.044753779768944



100%|██████████| 400/400 [00:28<00:00, 14.10it/s]

Loss at epoch 20 is 4.016096643209457



100%|██████████| 400/400 [00:27<00:00, 14.38it/s]

Loss at epoch 21 is 3.988048440217972



100%|██████████| 400/400 [00:28<00:00, 14.21it/s]

Loss at epoch 22 is 3.951083384156227



100%|██████████| 400/400 [00:27<00:00, 14.69it/s]

Loss at epoch 23 is 3.9261232513189315



100%|██████████| 400/400 [00:28<00:00, 14.14it/s]

Loss at epoch 24 is 3.8980646473169327



100%|██████████| 400/400 [00:27<00:00, 14.37it/s]

Loss at epoch 25 is 3.8766831022500994



100%|██████████| 400/400 [00:28<00:00, 14.11it/s]

Loss at epoch 26 is 3.844563713669777



100%|██████████| 400/400 [00:27<00:00, 14.29it/s]

Loss at epoch 27 is 3.8213777244091034



100%|██████████| 400/400 [00:27<00:00, 14.62it/s]

Loss at epoch 28 is 3.7939903283119203



100%|██████████| 400/400 [00:27<00:00, 14.65it/s]

Loss at epoch 29 is 3.781476872563362



100%|██████████| 400/400 [00:27<00:00, 14.33it/s]

Loss at epoch 30 is 3.7507964104413984



100%|██████████| 400/400 [00:26<00:00, 15.00it/s]


Loss at epoch 31 is 3.735007284283638


100%|██████████| 400/400 [00:27<00:00, 14.49it/s]

Loss at epoch 32 is 3.712365919947624



100%|██████████| 400/400 [00:26<00:00, 15.13it/s]


Loss at epoch 33 is 3.691937118768692


100%|██████████| 400/400 [00:26<00:00, 15.10it/s]

Loss at epoch 34 is 3.6693595492839814



100%|██████████| 400/400 [00:26<00:00, 15.08it/s]

Loss at epoch 35 is 3.6579600262641905



100%|██████████| 400/400 [00:28<00:00, 14.16it/s]

Loss at epoch 36 is 3.630110812187195



100%|██████████| 400/400 [00:28<00:00, 14.13it/s]

Loss at epoch 37 is 3.6230029392242433



100%|██████████| 400/400 [00:27<00:00, 14.34it/s]

Loss at epoch 38 is 3.6041781163215636



100%|██████████| 400/400 [00:28<00:00, 14.10it/s]

Loss at epoch 39 is 3.5868296349048614



100%|██████████| 400/400 [00:28<00:00, 14.11it/s]

Loss at epoch 40 is 3.57201711833477



100%|██████████| 400/400 [00:28<00:00, 14.07it/s]

Loss at epoch 41 is 3.541877045035362



100%|██████████| 400/400 [00:28<00:00, 14.24it/s]

Loss at epoch 42 is 3.537510045170784



100%|██████████| 400/400 [00:28<00:00, 14.17it/s]

Loss at epoch 43 is 3.527223652005196



100%|██████████| 400/400 [00:27<00:00, 14.56it/s]

Loss at epoch 44 is 3.491117052435875



100%|██████████| 400/400 [00:27<00:00, 14.32it/s]

Loss at epoch 45 is 3.48798057615757



100%|██████████| 400/400 [00:26<00:00, 14.98it/s]

Loss at epoch 46 is 3.479614748954773



100%|██████████| 400/400 [00:26<00:00, 15.07it/s]

Loss at epoch 47 is 3.4601959371566773



100%|██████████| 400/400 [00:28<00:00, 14.13it/s]

Loss at epoch 48 is 3.4572765928506852



100%|██████████| 400/400 [00:28<00:00, 14.26it/s]

Loss at epoch 49 is 3.430015920996666
Finished Training





In [11]:
# Persist the fine-tuned weights (encoder + classifier head).
save_model_location = './'
os.makedirs(save_model_location, exist_ok=True)  # exist_ok: safe to re-run
torch.save(
    net.state_dict(),
    os.path.join(save_model_location, "xx_resnet18_v2.pth"),
)

In [12]:
# Validation split with the deterministic transform; order kept fixed.
evalset = CustomDataset(
    root='/dataset',
    split="val",
    transform=eval_transform,
)
evalloader = torch.utils.data.DataLoader(
    evalset,
    batch_size=256,
    shuffle=False,
    num_workers=2,
)

In [13]:
# Top-1 accuracy of the fine-tuned model on the validation split.
net.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in evalloader:
        images = images.cuda()
        labels = labels.cuda()

        # Index of the highest logit along the class dimension = predicted class.
        predicted = net(images).argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

if total == 0:
    # Report the real cause instead of an opaque ZeroDivisionError.
    raise ValueError("evaluation loader yielded no samples")

print(f"Team {team_id}: {team_name} Accuracy: {(100 * correct / total):.2f}%")

Team 15: loSSLess Accuracy: 20.90%
