In [None]:
import os

import cv2
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from cvzone.PoseModule import PoseDetector
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
# from torchvision import ImageDatasets

class FashionDataset(Dataset):
    """Paired person/clothes image dataset read from two directories.

    Item ``idx`` pairs the ``idx``-th person file with the ``idx``-th clothes
    file. Both directory listings are sorted by file name so that the pairing
    is deterministic; ``os.listdir`` on its own returns entries in arbitrary
    order.

    Args:
        person_dir: Directory containing the person images.
        clothes_dir: Directory containing the clothes images.
        transform: Optional callable applied to each loaded RGB PIL image.
    """

    def __init__(self, person_dir, clothes_dir, transform=None):
        self.person_dir = person_dir
        self.clothes_dir = clothes_dir
        self.transform = transform
        # Sort so that index i refers to the same pair on every run/platform.
        self.person_images = sorted(os.listdir(person_dir))
        self.clothes_images = sorted(os.listdir(clothes_dir))

    def __len__(self):
        """Return the number of pairs, limited by the smaller directory."""
        return min(len(self.person_images), len(self.clothes_images))

    def __getitem__(self, idx):
        """Load and return the ``(person_image, clothes_image)`` pair at ``idx``.

        Both files are opened with PIL and converted to RGB; ``transform`` is
        applied to each image separately when one was supplied.
        """
        person_img_path = os.path.join(self.person_dir, self.person_images[idx])
        clothes_img_path = os.path.join(self.clothes_dir, self.clothes_images[idx])

        person_image = Image.open(person_img_path).convert('RGB')
        clothes_image = Image.open(clothes_img_path).convert('RGB')

        if self.transform is not None:
            person_image = self.transform(person_image)
            clothes_image = self.transform(clothes_image)

        return person_image, clothes_image

# Shared preprocessing: resize each image to 256x192 (torchvision's Resize
# takes the size as (height, width)) and convert it to a tensor.
transform = transforms.Compose(
    [transforms.Resize((256, 192)), transforms.ToTensor()]
)

# Training split: shuffled batches of 16 pairs.
train_dataset = FashionDataset(
    person_dir='ImageDataset/train/person',
    clothes_dir='ImageDataset/train/cloth',
    transform=transform,
)
train_loader = DataLoader(dataset=train_dataset, batch_size=16, shuffle=True)

# Test split: same batch size, iterated in dataset order.
test_dataset = FashionDataset(
    person_dir='ImageDataset/test/person',
    clothes_dir='ImageDataset/test/cloth',
    transform=transform,
)
test_loader = DataLoader(dataset=test_dataset, batch_size=16, shuffle=False)

class SimpleUNet(nn.Module):
    """Minimal convolutional encoder/decoder over 3-channel images.

    The encoder is a 3x3 convolution (3 -> 64 channels) followed by ReLU and
    a 2x2 max-pool, which halves height and width. The decoder is a stride-2
    transposed convolution (64 -> 3 channels) followed by a sigmoid, which
    doubles height and width again and maps values into [0, 1].
    """

    def __init__(self):
        super().__init__()

        downsample_layers = [
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        upsample_layers = [
            nn.ConvTranspose2d(in_channels=64, out_channels=3, kernel_size=2, stride=2),
            nn.Sigmoid(),
        ]

        self.encoder = nn.Sequential(*downsample_layers)
        self.decoder = nn.Sequential(*upsample_layers)

    def forward(self, x):
        """Run ``x`` through the encoder and then the decoder."""
        return self.decoder(self.encoder(x))

# Model instance shared by the training, evaluation and try-on code.
model = SimpleUNet()

# Training loop
epochs = 5
learning_rate = 0.001
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Fail early with a clear message: with zero batches the per-epoch average
# below would otherwise divide by zero.
num_train_batches = len(train_loader)
if num_train_batches == 0:
    raise ValueError('train_loader yielded no batches; check the training image directories.')

for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    # Only the person batch is used: the loss compares the model output with
    # the person images it was given (reconstruction objective).
    for person_images, clothes_images in train_loader:
        optimizer.zero_grad()
        outputs = model(person_images)
        loss = criterion(outputs, person_images)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # Report the mean batch loss for this epoch.
    print(f'Epoch [{epoch + 1}/{epochs}], Loss: {running_loss / num_train_batches:.4f}')

# Evaluation loop
model.eval()
test_loss = 0.0
num_test_batches = len(test_loader)

# Gradients are not needed for evaluation.
with torch.no_grad():
    for person_images, clothes_images in test_loader:
        outputs = model(person_images)
        loss = criterion(outputs, person_images)
        test_loss += loss.item()

# Guard the average: an empty test loader would otherwise raise
# ZeroDivisionError here.
if num_test_batches:
    print(f'Test Loss: {test_loss / num_test_batches:.4f}')
else:
    print('Test Loss: n/a (test_loader yielded no batches)')

# Pose detection with cvzone
# Single module-level detector instance, used by get_keypoints().
pose_detector = PoseDetector()

def get_keypoints(image_path):
    """Detect body-pose landmarks in the image stored at ``image_path``.

    Args:
        image_path: Path of the image file to analyse.

    Returns:
        The landmark list reported by the module-level ``pose_detector``, or
        ``None`` when the image cannot be read or no landmarks are found.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread reports a missing/unreadable file by returning None.
        return None

    # findPose() runs the detector and returns only the (annotated) image;
    # the landmarks themselves are retrieved afterwards with findPosition().
    image = pose_detector.findPose(image)
    keypoints, _ = pose_detector.findPosition(image)

    # Normalise "nothing detected" (empty list) to None so that callers can
    # test the result with ``is None``.
    return keypoints or None

def align_clothes(person_image, clothes_image, keypoints):
    """Placeholder for pose-based garment alignment.

    No alignment is implemented yet: ``person_image`` and ``keypoints`` are
    accepted but ignored, and ``clothes_image`` is returned unchanged.
    """
    aligned = clothes_image  # pass-through until real alignment logic exists
    return aligned

def test_virtual_try_on_with_pose(person_image_path, clothes_image_path):
    """Run the model on one person image and display the result.

    Loads both images as RGB, obtains pose keypoints for the person image,
    passes the clothes image through ``align_clothes``, feeds the transformed
    person image to the module-level ``model`` and shows the output with
    matplotlib. Prints a message and returns early when ``get_keypoints``
    returns ``None``.

    Args:
        person_image_path: Path of the person image file.
        clothes_image_path: Path of the clothes image file.
    """
    model.eval()

    person_rgb, clothes_rgb = (
        Image.open(path).convert('RGB')
        for path in (person_image_path, clothes_image_path)
    )

    pose = get_keypoints(person_image_path)
    if pose is None:
        print("No keypoints detected.")
        return

    warped_clothes = align_clothes(person_rgb, clothes_rgb, pose)

    # Add a leading batch dimension to each transformed image.
    person_batch = transform(person_rgb).unsqueeze(0)
    # Prepared here but currently not fed to the model.
    clothes_batch = transform(warped_clothes).unsqueeze(0)

    with torch.no_grad():
        prediction = model(person_batch)

    to_pil = transforms.ToPILImage()
    plt.imshow(to_pil(prediction.squeeze()))
    plt.show()

# Example usage: runs the try-on demo on one person/cloth image pair from the
# test directories and displays the model output.
test_virtual_try_on_with_pose('ImageDataset/test/person/example_person.jpg', 'ImageDataset/test/cloth/example_cloth.jpg')



: 

In [5]:
%pip install cvzone

Collecting cvzone
  Downloading cvzone-1.6.1.tar.gz (25 kB)
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
Building wheels for collected packages: cvzone
  Building wheel for cvzone (pyproject.toml) ... [?25ldone
[?25h  Created wheel for cvzone: filename=cvzone-1.6.1-py3-none-any.whl size=26295 sha256=5db4e6e1fa0b0096b3d6a1fe044f9521a174491dc059b9f678aca3a609d85474
  Stored in directory: /home/raimudit2003/.cache/pip/wheels/5d/21/e8/3147ae88d44e27f06e0175d337a7673c70fb957202cbbe2034
Successfully built cvzone
Installing collected packages: cvzone
Successfully installed cvzone-1.6.1
Note: you may need to restart the kernel to use updated packages.
