In [1]:
## Imports
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

### Network basics

In [2]:

class MultiSupervisionNetwork(nn.Module):
    """Single-hidden-layer MLP with two classification heads.

    A shared ``Linear -> ReLU`` trunk feeds two independent linear heads, so
    one forward pass produces an output for each of two supervision signals.

    Args:
        input_size: Number of features in each input vector.
        hidden_size: Width of the shared hidden layer.
        num_classes1: Output size of the first head.
        num_classes2: Output size of the second head.
    """

    def __init__(self, input_size, hidden_size, num_classes1, num_classes2):
        super().__init__()
        ## Shared trunk
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        ## One linear head per supervision signal, both fed by the trunk
        self.fc2_1 = nn.Linear(hidden_size, num_classes1)
        self.fc2_2 = nn.Linear(hidden_size, num_classes2)

    def forward(self, x):
        """Run ``x`` through the trunk and return ``(out1, out2)``, one per head."""
        hidden = self.relu(self.fc1(x))
        return self.fc2_1(hidden), self.fc2_2(hidden)

## Demo: push one random sample through the network and combine both task losses

## Layer sizes for the toy example
input_size, hidden_size = 784, 256
num_classes1, num_classes2 = 10, 5

## One cross-entropy criterion per head
criterion1, criterion2 = nn.CrossEntropyLoss(), nn.CrossEntropyLoss()

## Placeholder class indices, one per head (replace with your actual targets)
target1, target2 = torch.tensor([2]), torch.tensor([0])

## Build the model first, then draw a single random input row (batch size 1)
model = MultiSupervisionNetwork(input_size, hidden_size, num_classes1, num_classes2)
input_tensor = torch.randn(1, input_size)

output1, output2 = model(input_tensor)
print("Output 1 shape:", output1.shape)
print("Output 2 shape:", output2.shape)

## Per-head losses
loss1 = criterion1(output1, target1)
loss2 = criterion2(output2, target2)

## Weighted sum of the two losses (example weights)
total_loss = 0.7 * loss1 + 0.3 * loss2

print(f"Loss 1: {loss1}")
print(f"Loss 2: {loss2}")
print(f"Total loss: {total_loss}")

## Backpropagation and optimization would follow here



Output 1 shape: torch.Size([1, 10])
Output 2 shape: torch.Size([1, 5])
Loss 1: 2.350739002227783
Loss 2: 1.4947116374969482
Total loss: 2.093930721282959


In [None]:
## Reproducibility: weight initialisation, DataLoader shuffling and the dummy
## second-task labels below all draw from torch's RNG, so seed it up front.
seed = 0
torch.manual_seed(seed)

## Hyperparameters
input_size = 3 * 224 * 224 ## Flattened size of one 3x224x224 image (must match the Resize below)
hidden_size = 512
num_classes1 = 102  ## Number of classes for Flower-102 (or adjust)
num_classes2 = 10 ## Example second supervision (e.g., coarse-grained categories)
learning_rate = 0.001
num_epochs = 10
batch_size = 32
loss_weight1 = 0.8 ## Example weight of the first-task loss in the combined loss
loss_weight2 = 0.2 ## Example weight of the second-task loss in the combined loss
log_interval = 100 ## Print per-batch losses every `log_interval` batches

## Train on the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Data loading and preprocessing (Flower-102)
## Ensure you have the dataset downloaded correctly. If not, see torchvision documentation.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ## Example normalization
])


## Create data loaders
train_dataset = datasets.Flowers102(root='./data', split='train', download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)


## Model, Loss functions, and Optimizer
model = MultiSupervisionNetwork(input_size, hidden_size, num_classes1, num_classes2).to(device)
criterion1 = nn.CrossEntropyLoss()
criterion2 = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

## Training loop
model.train()
for epoch in range(num_epochs):
    epoch_loss_sum = 0.0  ## Running sum of the combined loss over this epoch

    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.to(device)
        target = target.to(device)

        ## Flatten each image to a vector. Keeping the batch dimension explicit
        ## means an unexpected image size raises a shape error in the first
        ## linear layer instead of silently re-batching the samples.
        data = data.view(data.size(0), -1)

        ## Create dummy second target (replace with your second supervision signal)
        target2 = torch.randint(0, num_classes2, (target.size(0),), device=device)

        ## Forward pass
        output1, output2 = model(data)
        loss1 = criterion1(output1, target)
        loss2 = criterion2(output2, target2)  ## Replace target2 with actual second supervision
        total_loss = loss_weight1 * loss1 + loss_weight2 * loss2

        ## Backward pass and optimization
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        epoch_loss_sum += total_loss.item()

        if batch_idx % log_interval == 0:
            print(f"Epoch: {epoch+1}/{num_epochs}, Batch: {batch_idx}/{len(train_loader)}, Loss1: {loss1.item():.4f}, Loss2: {loss2.item():.4f}, Total Loss: {total_loss.item():.4f}")

    ## The Flowers102 train split is small (1020 images per the torchvision
    ## docs, i.e. about 32 batches here), so the per-batch log above only fires
    ## on batch 0. Report the epoch mean as well so progress is visible.
    print(f"Epoch: {epoch+1}/{num_epochs}, Mean Total Loss: {epoch_loss_sum / max(len(train_loader), 1):.4f}")