In [None]:
import os

from dask.distributed import Client

# Address of the Dask scheduler. It can be overridden without editing the
# notebook by exporting DASK_SCHEDULER_ADDRESS; the fallback is the address
# this notebook was originally written against.
scheduler_address = os.environ.get("DASK_SCHEDULER_ADDRESS", "tcp://10.108.56.79:8786")
client = Client(scheduler_address)
print(client)

In [None]:
from datasets import load_dataset, DatasetDict, Dataset
import pandas as pd

# Load the "Bingsu/Cat_and_Dog" dataset through `datasets.load_dataset`.
# Later cells read its 'train' and 'test' splits.
dataset = load_dataset("Bingsu/Cat_and_Dog")

In [None]:
import torch
import torch.nn as nn

# Fetch a pretrained ResNet18 from the 'pytorch/vision' hub repository and
# keep it on the CPU so that no GPU is required anywhere in this example.
model = torch.hub.load('pytorch/vision', 'resnet18', pretrained=True).to('cpu')

In [None]:
import torchvision
# Ship the model to the cluster once and keep a future that refers to it, so
# tasks can take the future as an argument instead of re-sending the model.
model_future = client.scatter(model, broadcast=True) # send the model to each worker

In [None]:
import torch.optim as optim
import torch.nn.functional as F
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm

def train(model, data_loader, epochs, lr, device='cpu'):
    """Train ``model`` in place with plain SGD and a cross-entropy loss.

    Args:
        model: torch module to optimise. It is moved to ``device`` and its
            parameters are updated in place.
        data_loader: iterable yielding ``(images, labels)`` batches. It is
            iterated once per epoch, so it must be re-iterable when
            ``epochs > 1``. It does not need to support ``len()``.
        epochs: number of passes over ``data_loader``.
        lr: learning rate handed to ``optim.SGD``.
        device: device used for the model and every batch (default 'cpu').

    Returns:
        The same ``model`` object that was passed in, left in training mode.
    """
    model.to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr)
    model.train()

    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0
        with tqdm(data_loader, desc=f'Epoch {epoch + 1}/{epochs}', unit='batch') as pbar:
            for images, labels in pbar:
                images, labels = images.to(device), labels.to(device)
                optimizer.zero_grad()
                outputs = model(images)
                loss = F.cross_entropy(outputs, labels)
                loss.backward()
                optimizer.step()

                # Read the scalar once and reuse it for the total and the bar.
                batch_loss = loss.item()
                total_loss += batch_loss
                num_batches += 1
                pbar.set_postfix({'loss': batch_loss})

        # Count the batches actually seen instead of calling len(): this works
        # for unsized iterables and cannot divide by zero on an empty loader.
        average_loss = total_loss / num_batches if num_batches else float('nan')
        print(f'Epoch {epoch+1}, Average Loss: {average_loss}')
    return model

In [None]:
from dask.distributed import Client, wait
from datasets import load_dataset
import torch
import torch.nn as nn
from torch.optim import SGD
from torch.utils.data import DataLoader, Subset, Dataset as TorchDataset
from torchvision import transforms, models
from PIL import Image
import pandas as pd
import numpy as np
from tqdm import tqdm
from itertools import chain


# Define a custom dataset class
class CustomImageDataset(TorchDataset):
    """Map-style dataset that loads images from paths stored in a DataFrame.

    Args:
        df: DataFrame with one row per sample.
        transform: optional callable applied to the RGB PIL image.
        file_column: name of the column holding the image path (default 'file').
        label_column: name of the column holding the label (default 'label').
    """

    def __init__(self, df, transform=None, file_column='file', label_column='label'):
        self.df = df
        self.transform = transform
        self.file_column = file_column
        self.label_column = label_column

    def __len__(self):
        """Return the number of rows (samples) in the DataFrame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the row at integer position ``idx``."""
        # Fetch the row once instead of once per column.
        row = self.df.iloc[idx]
        image_path = row[self.file_column]
        label = row[self.label_column]
        # convert() returns a new, fully loaded image, so the underlying file
        # can be closed as soon as the with-block exits.
        with Image.open(image_path) as source:
            image = source.convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, label
    
# Define transforms: resize to 224x224, convert to a tensor, then normalise
# each channel.
# NOTE(review): the mean/std values are presumably the ImageNet statistics the
# pretrained ResNet expects — confirm.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Convert dataset to CustomImageDataset
# NOTE(review): CustomImageDataset reads the 'file' and 'label' columns by
# default — confirm that both splits actually expose columns with these names.
train_df = pd.DataFrame(dataset['train'])
test_df = pd.DataFrame(dataset['test'])
train_dataset = CustomImageDataset(train_df, transform=transform)
test_dataset = CustomImageDataset(test_df, transform=transform)

# Create DataLoader with batch size 32
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Accuracy does not depend on sample order, so the evaluation loader is left
# unshuffled to keep evaluation runs deterministic.
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Training loop with batch dispatching and model aggregation
epochs = 7
learning_rate = 0.001
num_workers = 4


def average_state_dicts(state_dicts):
    """Return the element-wise mean of several model state dicts.

    Args:
        state_dicts: non-empty sequence of state dicts that share the same keys
            and tensor shapes.

    Returns:
        A dict with the keys of the first state dict. Floating-point tensors
        are averaged directly; other tensors (e.g. integer counters) are
        averaged in float and cast back to their original dtype.
    """
    averaged = {}
    for key, reference in state_dicts[0].items():
        stacked = torch.stack([state[key] for state in state_dicts])
        if reference.is_floating_point():
            averaged[key] = stacked.mean(dim=0)
        else:
            averaged[key] = stacked.float().mean(dim=0).to(reference.dtype)
    return averaged


for epoch in range(epochs):
    batch_iter = iter(train_loader)

    while True:
        # Pull at most one batch per worker. Each batch is wrapped in its own
        # one-element list because train() iterates over (images, labels)
        # pairs. A final, smaller group is still trained on rather than dropped.
        batch_groups = []
        for _ in range(num_workers):
            batch = next(batch_iter, None)
            if batch is None:
                break
            batch_groups.append([batch])
        if not batch_groups:
            break  # End of DataLoader

        # Submit training tasks. train() mutates the model it receives, so the
        # tasks are marked impure to stop Dask from reusing cached results.
        # NOTE(review): tasks that land on the same worker may receive the same
        # scattered model object — confirm whether train() should copy it first.
        futures = [
            client.submit(train, model_future, batch_group, 1, learning_rate, pure=False)
            for batch_group in batch_groups
        ]

        # Wait for all tasks to complete; each result is the model train() returned
        worker_models = client.gather(futures)

        # Synchronize models by averaging their weights
        updated_state_dict = average_state_dicts(
            [worker_model.state_dict() for worker_model in worker_models]
        )

        # Update the model with the averaged state dict
        model.load_state_dict(updated_state_dict)

        # Scatter the updated model back to the workers
        model_future = client.scatter(model, broadcast=True)

# Scattering a single object yields a single future, so gather() returns the
# model itself and there is nothing to index into.
trained_model = client.gather(model_future)

In [None]:
def test_model(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = correct / total
    return accuracy

# Score the model produced by the training cell on the test DataLoader and
# report the accuracy rounded to four decimal places.
test_accuracy = test_model(trained_model, test_loader)
print(f"Test Accuracy: {test_accuracy:.4f}")

In [None]:
# Close the Dask client created in the first cell now that the run is finished.
client.close()