In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
from tqdm import tqdm     
from PIL import Image

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
ds = load_dataset("Bingsu/Cat_and_Dog")

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
Generating train split: 100%|██████████| 8000/8000 [00:00<00:00, 40192.55 examples/s]
Generating test split: 100%|██████████| 2000/2000 [00:00<00:00, 35024.33 examples/s]


In [4]:
class ImageDataset(Dataset):
    def __init__(self, data, compose, type):
        super().__init__()
        self.data = data
        self.compose = compose
        self.type = type

    def __len__(self):
        return len(self.data[self.type])  # Fix dataset length

    def __getitem__(self, index):
        image = self.data[self.type][index]['image']
        label = torch.tensor(self.data[self.type][index]['labels'])
        image = self.compose(image)

        return image, label

In [5]:
compose = T.Compose([
    T.Resize((64, 64)),  # Randomly crop the image to 224x224 pixels
    T.RandomHorizontalFlip(p=0.5),  # Randomly flip the image horizontally with a probability of 0.5
    T.RandomVerticalFlip(p=0.5),  # Randomly flip the image vertically with a probability of 0.5
    T.RandomRotation(degrees=30),  # Randomly rotate the image by up to 30 degrees
    T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # Randomly adjust brightness, contrast, saturation, and hue
    T.ToTensor(),  # Convert the image to a PyTorch tensor
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [6]:
class ResNetBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,stride = stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.LeakyReLU(inplace=True)

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
    def forward(self,x):
        identity = self.shortcut(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        
        return self.relu(out)


    def __init__(self, in_channels, out_channels, stride = 1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels,kernel_size = 3, stride=stride, padding = 1, bias = False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.act = nn.GELU()
        
        self.conv2 = nn.Conv2d(out_channels,out_channels,kernel_size = 3, stride = 1, padding=1, bias = False)
        self.bn2 = nn.BatchNorm2d(out_channels)



In [7]:
class SimpleResNet(nn.Module):
    def __init__(self, num_classes=10):
        super(SimpleResNet, self).__init__()

        # Initial Convolution Layer
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.GELU()

        # ResNet Blocks
        self.res_block1 = ResNetBlock(64, 128, stride=2)  # Downsample
        self.res_block2 = ResNetBlock(128, 256, stride=2) # Downsample
        self.res_block3 = ResNetBlock(256, 512, stride=2) # Downsample
        self.res_block4 = ResNetBlock(512,512, stride=2)

        #remove when low perf

        # Global Average Pooling and Fully Connected Layer
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # Outputs (batch, 512, 1, 1)
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.res_block1(x)
        x = self.res_block2(x)
        x = self.res_block3(x)

        x = self.avgpool(x)  
        x = torch.flatten(x, 1)  # Flatten for FC layer
        x = self.fc(x)
        return x

In [8]:
model  = SimpleResNet(num_classes=1)

In [9]:
train_dataset = ImageDataset(ds, compose, 'train')
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=16)

test_dataset = ImageDataset(ds, compose, 'test')
test_dataloader = DataLoader(test_dataset, shuffle=True, batch_size=1)

In [10]:
image, label = train_dataset[0]
image.shape

torch.Size([3, 64, 64])

In [11]:
@torch.no_grad()
def evaluate_model(test):
	model.eval()
	print("Running Validation")
	correct = 0
	for X,y in tqdm(test):
		X = X.to(device)
		y = y.view(-1, 1).float().to(device)
		out = torch.sigmoid(model(X))
		prediction = (out >= 0.5).int()
		correct += prediction.item() == y.item()
	accuracy = (correct/len(test))*100
	print(f"Accuracy : {accuracy}")
	return accuracy

In [12]:
loss_fn = nn.BCEWithLogitsLoss()  # Change to BCEWithLogitsLoss
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-4, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, verbose=True)



In [13]:
max_acc = 0

In [16]:

EPOCHS = 50
model.load_state_dict(torch.load('modelBest.pth'))
model.to(device)
for epoch in range(EPOCHS):
    sum_loss = 0
    for X, y in tqdm(train_dataloader):
        model.train()
        X = X.to(device)
        y = y.view(-1, 1).float().to(device)  # Ensure correct shape

        # Forward
        out = model(X)
        loss = loss_fn(out, y)
        
        sum_loss = sum_loss + loss.item()
        # Backward
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(),max_norm=5.0)
        optimizer.step()
    avg = sum_loss/len(train_dataloader)
    scheduler.step(avg)
    print(f"Epoch [{epoch+1}/{EPOCHS}], Loss: {avg:.4f}")
    acc = evaluate_model(test_dataloader)
    if(acc>max_acc):
        torch.save(model.state_dict(), 'model.pth')
        max_acc = acc

    


  model.load_state_dict(torch.load('modelBest.pth'))
100%|██████████| 500/500 [00:28<00:00, 17.85it/s]


Epoch [1/50], Loss: 0.2295
Running Validation


100%|██████████| 2000/2000 [00:10<00:00, 192.66it/s]


Accuracy : 86.1


100%|██████████| 500/500 [00:28<00:00, 17.83it/s]


Epoch [2/50], Loss: 0.2079
Running Validation


100%|██████████| 2000/2000 [00:09<00:00, 209.79it/s]


Accuracy : 86.45


100%|██████████| 500/500 [00:27<00:00, 18.01it/s]


Epoch [3/50], Loss: 0.2042
Running Validation


100%|██████████| 2000/2000 [00:09<00:00, 201.80it/s]


Accuracy : 88.05


100%|██████████| 500/500 [00:27<00:00, 18.45it/s]


Epoch [4/50], Loss: 0.1950
Running Validation


100%|██████████| 2000/2000 [00:10<00:00, 199.40it/s]


Accuracy : 87.45


100%|██████████| 500/500 [00:26<00:00, 18.52it/s]


Epoch [5/50], Loss: 0.1981
Running Validation


100%|██████████| 2000/2000 [00:10<00:00, 192.29it/s]


Accuracy : 88.4


100%|██████████| 500/500 [00:27<00:00, 18.48it/s]


Epoch [6/50], Loss: 0.1989
Running Validation


100%|██████████| 2000/2000 [00:09<00:00, 204.74it/s]


Accuracy : 88.6


100%|██████████| 500/500 [00:26<00:00, 18.54it/s]


Epoch [7/50], Loss: 0.1755
Running Validation


100%|██████████| 2000/2000 [00:09<00:00, 200.97it/s]


Accuracy : 89.45


100%|██████████| 500/500 [00:27<00:00, 18.26it/s]


Epoch [8/50], Loss: 0.1836
Running Validation


100%|██████████| 2000/2000 [00:09<00:00, 206.50it/s]


Accuracy : 88.14999999999999


100%|██████████| 500/500 [00:26<00:00, 18.53it/s]


Epoch [9/50], Loss: 0.1771
Running Validation


100%|██████████| 2000/2000 [00:10<00:00, 187.86it/s]


Accuracy : 88.2


100%|██████████| 500/500 [00:30<00:00, 16.33it/s]


Epoch [10/50], Loss: 0.1804
Running Validation


100%|██████████| 2000/2000 [00:11<00:00, 176.07it/s]


Accuracy : 88.44999999999999


100%|██████████| 500/500 [00:30<00:00, 16.62it/s]


Epoch [11/50], Loss: 0.1784
Running Validation


100%|██████████| 2000/2000 [00:10<00:00, 183.57it/s]


Accuracy : 88.94999999999999


100%|██████████| 500/500 [00:28<00:00, 17.84it/s]


Epoch [12/50], Loss: 0.1769
Running Validation


100%|██████████| 2000/2000 [00:09<00:00, 201.72it/s]


Accuracy : 88.2


100%|██████████| 500/500 [00:27<00:00, 18.49it/s]


Epoch [13/50], Loss: 0.1727
Running Validation


100%|██████████| 2000/2000 [00:10<00:00, 192.03it/s]


Accuracy : 89.1


100%|██████████| 500/500 [00:28<00:00, 17.68it/s]


Epoch [14/50], Loss: 0.1675
Running Validation


100%|██████████| 2000/2000 [00:09<00:00, 206.18it/s]


Accuracy : 87.9


100%|██████████| 500/500 [00:27<00:00, 17.93it/s]


Epoch [15/50], Loss: 0.1650
Running Validation


100%|██████████| 2000/2000 [00:11<00:00, 179.50it/s]


Accuracy : 88.4


100%|██████████| 500/500 [00:27<00:00, 18.25it/s]


Epoch [16/50], Loss: 0.1594
Running Validation


100%|██████████| 2000/2000 [00:09<00:00, 206.32it/s]


Accuracy : 88.05


100%|██████████| 500/500 [00:26<00:00, 18.56it/s]


Epoch [17/50], Loss: 0.1617
Running Validation


100%|██████████| 2000/2000 [00:10<00:00, 194.86it/s]


Accuracy : 89.35


100%|██████████| 500/500 [00:27<00:00, 17.88it/s]


Epoch [18/50], Loss: 0.1667
Running Validation


100%|██████████| 2000/2000 [00:10<00:00, 198.95it/s]


Accuracy : 87.7


100%|██████████| 500/500 [00:27<00:00, 18.35it/s]


Epoch [19/50], Loss: 0.1591
Running Validation


100%|██████████| 2000/2000 [00:10<00:00, 195.53it/s]


Accuracy : 88.8


100%|██████████| 500/500 [00:27<00:00, 18.24it/s]


Epoch [20/50], Loss: 0.1607
Running Validation


100%|██████████| 2000/2000 [00:10<00:00, 198.91it/s]


Accuracy : 88.7


100%|██████████| 500/500 [00:27<00:00, 18.01it/s]


Epoch [21/50], Loss: 0.1537
Running Validation


100%|██████████| 2000/2000 [00:10<00:00, 196.48it/s]


Accuracy : 88.0


100%|██████████| 500/500 [00:27<00:00, 18.21it/s]


Epoch [22/50], Loss: 0.1547
Running Validation


100%|██████████| 2000/2000 [00:09<00:00, 211.85it/s]


Accuracy : 88.05


100%|██████████| 500/500 [00:27<00:00, 18.43it/s]


Epoch [23/50], Loss: 0.1577
Running Validation


100%|██████████| 2000/2000 [00:09<00:00, 212.11it/s]


Accuracy : 88.6


100%|██████████| 500/500 [00:27<00:00, 18.26it/s]


Epoch [24/50], Loss: 0.1452
Running Validation


100%|██████████| 2000/2000 [00:09<00:00, 214.30it/s]


Accuracy : 87.55


100%|██████████| 500/500 [00:27<00:00, 18.06it/s]


Epoch [25/50], Loss: 0.1641
Running Validation


100%|██████████| 2000/2000 [00:10<00:00, 188.19it/s]


Accuracy : 88.2


100%|██████████| 500/500 [00:29<00:00, 17.21it/s]


Epoch [26/50], Loss: 0.1535
Running Validation


100%|██████████| 2000/2000 [00:11<00:00, 177.98it/s]


Accuracy : 89.75


100%|██████████| 500/500 [00:28<00:00, 17.25it/s]


Epoch [27/50], Loss: 0.1454
Running Validation


100%|██████████| 2000/2000 [00:11<00:00, 179.61it/s]


Accuracy : 88.85


100%|██████████| 500/500 [00:28<00:00, 17.29it/s]


Epoch [28/50], Loss: 0.1395
Running Validation


100%|██████████| 2000/2000 [00:10<00:00, 199.04it/s]


Accuracy : 88.75


100%|██████████| 500/500 [00:27<00:00, 18.04it/s]


Epoch [29/50], Loss: 0.1378
Running Validation


100%|██████████| 2000/2000 [00:10<00:00, 194.42it/s]


Accuracy : 88.6


100%|██████████| 500/500 [00:28<00:00, 17.67it/s]


Epoch [30/50], Loss: 0.1373
Running Validation


100%|██████████| 2000/2000 [00:09<00:00, 204.24it/s]


Accuracy : 88.3


100%|██████████| 500/500 [00:27<00:00, 18.19it/s]


Epoch [31/50], Loss: 0.1361
Running Validation


100%|██████████| 2000/2000 [00:10<00:00, 187.75it/s]


Accuracy : 88.5


100%|██████████| 500/500 [00:29<00:00, 17.10it/s]


Epoch [32/50], Loss: 0.1402
Running Validation


100%|██████████| 2000/2000 [00:10<00:00, 191.03it/s]


Accuracy : 88.6


100%|██████████| 500/500 [00:28<00:00, 17.84it/s]


Epoch [33/50], Loss: 0.1324
Running Validation


100%|██████████| 2000/2000 [00:09<00:00, 204.27it/s]


Accuracy : 88.75


100%|██████████| 500/500 [00:27<00:00, 18.27it/s]


Epoch [34/50], Loss: 0.1445
Running Validation


100%|██████████| 2000/2000 [00:10<00:00, 199.75it/s]


Accuracy : 88.75


100%|██████████| 500/500 [00:27<00:00, 17.87it/s]


Epoch [35/50], Loss: 0.1353
Running Validation


100%|██████████| 2000/2000 [00:10<00:00, 190.69it/s]


Accuracy : 89.3


100%|██████████| 500/500 [00:28<00:00, 17.43it/s]


Epoch [36/50], Loss: 0.1326
Running Validation


100%|██████████| 2000/2000 [00:11<00:00, 181.59it/s]


Accuracy : 90.05


100%|██████████| 500/500 [00:44<00:00, 11.21it/s]


Epoch [37/50], Loss: 0.1261
Running Validation


100%|██████████| 2000/2000 [00:19<00:00, 102.79it/s]


Accuracy : 89.05


100%|██████████| 500/500 [00:47<00:00, 10.56it/s]


Epoch [38/50], Loss: 0.1350
Running Validation


100%|██████████| 2000/2000 [00:19<00:00, 104.00it/s]


Accuracy : 89.35


100%|██████████| 500/500 [00:48<00:00, 10.38it/s]


Epoch [39/50], Loss: 0.1252
Running Validation


100%|██████████| 2000/2000 [00:12<00:00, 162.66it/s]


Accuracy : 89.85


100%|██████████| 500/500 [00:28<00:00, 17.69it/s]


Epoch [40/50], Loss: 0.1226
Running Validation


100%|██████████| 2000/2000 [00:10<00:00, 187.86it/s]


Accuracy : 89.75


100%|██████████| 500/500 [00:28<00:00, 17.72it/s]


Epoch [41/50], Loss: 0.1255
Running Validation


100%|██████████| 2000/2000 [00:10<00:00, 189.98it/s]


Accuracy : 89.7


100%|██████████| 500/500 [00:28<00:00, 17.45it/s]


Epoch [42/50], Loss: 0.1340
Running Validation


100%|██████████| 2000/2000 [00:10<00:00, 187.74it/s]


Accuracy : 89.55


100%|██████████| 500/500 [00:28<00:00, 17.68it/s]


Epoch [43/50], Loss: 0.1117
Running Validation


100%|██████████| 2000/2000 [00:10<00:00, 187.56it/s]


Accuracy : 89.45


100%|██████████| 500/500 [00:28<00:00, 17.71it/s]


Epoch [44/50], Loss: 0.1233
Running Validation


100%|██████████| 2000/2000 [00:10<00:00, 188.10it/s]


Accuracy : 89.55


100%|██████████| 500/500 [00:28<00:00, 17.80it/s]


Epoch [45/50], Loss: 0.1155
Running Validation


100%|██████████| 2000/2000 [00:10<00:00, 189.41it/s]


Accuracy : 89.9


100%|██████████| 500/500 [00:28<00:00, 17.76it/s]


Epoch [46/50], Loss: 0.1132
Running Validation


100%|██████████| 2000/2000 [00:10<00:00, 188.07it/s]


Accuracy : 89.05


100%|██████████| 500/500 [00:28<00:00, 17.58it/s]


Epoch [47/50], Loss: 0.1215
Running Validation


100%|██████████| 2000/2000 [00:10<00:00, 199.18it/s]


Accuracy : 89.4


100%|██████████| 500/500 [00:27<00:00, 17.95it/s]


Epoch [48/50], Loss: 0.1174
Running Validation


100%|██████████| 2000/2000 [00:10<00:00, 197.96it/s]


Accuracy : 89.7


100%|██████████| 500/500 [00:27<00:00, 18.17it/s]


Epoch [49/50], Loss: 0.1215
Running Validation


100%|██████████| 2000/2000 [00:10<00:00, 193.80it/s]


Accuracy : 88.55


100%|██████████| 500/500 [00:27<00:00, 17.87it/s]


Epoch [50/50], Loss: 0.0852
Running Validation


100%|██████████| 2000/2000 [00:10<00:00, 192.13it/s]

Accuracy : 90.05





In [15]:
comp = T.Compose([
    T.Resize((64, 64)),
    T.ToTensor(),  # Convert the image to a PyTorch tensor
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [16]:
img = comp(Image.open('doge.jpg').convert('RGB'))
img.shape

torch.Size([3, 64, 64])

In [17]:
model.eval()
out = model(img.unsqueeze(0).to(device))
out = F.sigmoid(out)
out

tensor([[0.9170]], device='cuda:0', grad_fn=<SigmoidBackward0>)