In [1]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
import torchvision.datasets as datasets

torch.manual_seed(21)

<torch._C.Generator at 0x11bac6650>

In [2]:
if torch.backends.mps.is_available():
  device = torch.device("mps")
elif torch.cuda.is_available():
  device = torch.device("cuda:0")
else:
  device = torch.device("cpu")
print(device)

mps


In [3]:
# Define the transformations: resize to 224x224, convert to tensor, and normalize.
transform = transforms.Compose([
	transforms.Resize((224, 224)),
	transforms.ToTensor(),
	transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # ImageNet mean and standard deviation
])

# Set the root folder path where the "1. Healthy" and "3. WSSV" subfolders are located.
data_root = "ShrimpDiseaseImageBD An Image Dataset for Computer Vision-Based Detection of Shrimp Diseases in Bangladesh/Root/Raw Images"

dataset = datasets.ImageFolder(root=data_root, transform=transform)

train_size = int(0.8 * len(dataset))
train_dataset, test_dataset = random_split(dataset, [train_size, len(dataset) - train_size])

batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader  = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

print(f"# of train images: {len(train_dataset)}")
print(f"# of test images: {len(test_dataset)}")

for images, labels in train_loader:
	print(f"Shape of one batch: {images.shape}")
	print(labels) # 0 = Healthy, 1 = WSSV
	break

# of train images: 584
# of test images: 147
Shape of one batch: torch.Size([32, 3, 224, 224])
tensor([1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0,
        1, 1, 1, 0, 0, 1, 1, 0])


In [4]:
class Block(nn.Module):
	def __init__(self, in_channels, out_channels, identity_downsample=None, stride=1):
		super(Block, self).__init__()
		self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
		self.bn1 = nn.BatchNorm2d(out_channels)
		self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
		self.bn2 = nn.BatchNorm2d(out_channels)
		self.relu = nn.ReLU()
		self.identity_downsample = identity_downsample

	def forward(self, x):
		identity = x
		x = self.conv1(x)
		x = self.bn1(x)
		x = self.relu(x)
		x = self.conv2(x)
		x = self.bn2(x)
		if self.identity_downsample is not None:
			identity = self.identity_downsample(identity)
		x += identity
		x = self.relu(x)
		return x

In [5]:
class ResNet_18(nn.Module):
	def __init__(self, image_channels, num_classes):
		super(ResNet_18, self).__init__()
		self.in_channels = 64
		self.conv1 = nn.Conv2d(image_channels, 64, kernel_size=7, stride=2, padding=3)
		self.bn1 = nn.BatchNorm2d(64)
		self.relu = nn.ReLU()
		self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

		# resnet layers
		self.layer1 = self.__make_layer(64, 64, stride=1)
		self.layer2 = self.__make_layer(64, 128, stride=2)
		self.layer3 = self.__make_layer(128, 256, stride=2)
		self.layer4 = self.__make_layer(256, 512, stride=2)

		self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
		self.fc = nn.Linear(512, num_classes)

	def __make_layer(self, in_channels, out_channels, stride):			
		identity_downsample = None
		if stride != 1:
			identity_downsample = self.identity_downsample(in_channels, out_channels)

		return nn.Sequential(
			Block(in_channels, out_channels, identity_downsample=identity_downsample, stride=stride), 
			Block(out_channels, out_channels)
		)

	def forward(self, x):
		x = self.conv1(x)
		x = self.bn1(x)
		x = self.relu(x)
		x = self.maxpool(x)

		x = self.layer1(x)
		x = self.layer2(x)
		x = self.layer3(x)
		x = self.layer4(x)

		x = self.avgpool(x)
		x = x.view(x.shape[0], -1)
		x = self.fc(x)
		return x 

	def identity_downsample(self, in_channels, out_channels):
		return nn.Sequential(
			nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1), 
			nn.BatchNorm2d(out_channels)
		)

In [6]:
model = ResNet_18(1, 10)

In [7]:
#count trainable parameters of the model
print(sum(p.numel() for p in model.parameters() if p.requires_grad))

12556426


In [8]:
#move the model to the device
model.to(device)
next(model.parameters()).is_cuda

False