<a href="https://colab.research.google.com/github/scaairesearch/da_demo/blob/main/Gradio_C1_C2_v3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install Gradio into the runtime; it is imported as `gr` below and used to build the demo UI.
!pip install gradio --quiet

In [2]:
import gradio as gr
import os
from PIL import Image
from torchvision import datasets,transforms
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from collections import OrderedDict
import pandas as pd
import io
import base64

In [3]:
# Mount Google Drive unless the mount point is already present in this runtime.
if os.path.exists('/content/gdrive'):
    print("Google Drive is already mounted.")
else:
    from google.colab import drive
    drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [4]:
# Case 1 samples. Going by the file name these are MNIST-M items that the non-DANN model
# misclassified and the DANN model classified correctly - TODO confirm against the script that wrote the file.
# NOTE(review): torch.load unpickles the file, so only load it from a trusted location.
list_c1 = torch.load('/content/gdrive/MyDrive/da_demo/cv/models/26_06/list_mnist_m_non_dann_misclassified_dann_classified.pt')

In [5]:
class CustomDataset(torch.utils.data.Dataset):
    """Dataset view over an indexable collection of ``(image, label, image_name)`` triples."""

    def __init__(self, data):
        # Keep a reference to the underlying collection; nothing is copied.
        self.data = data

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as an ``(image, label, image_name)`` tuple."""
        image, label, image_name = self.data[idx]
        return image, label, image_name

# Wrap the Case 1 sample list so it can be consumed by a DataLoader.
dataset_c1 = CustomDataset(list_c1)

In [6]:
# DataLoader over the Case 1 samples: batches of 10, reshuffled each time a new iterator is created.
dataloader_c1 = torch.utils.data.DataLoader(dataset_c1, batch_size=10, shuffle=True)

In [7]:
# Tensor -> PIL converter shared by the example-image helpers.
transform_to_pil = transforms.ToPILImage()


def get_images():
    """Draw one shuffled batch of Case 1 samples.

    Returns:
        tuple: ``(pil_images, labels)`` - a list of PIL images and the matching labels as a
        plain Python list. The image names stored with each sample are not returned.
    """
    # NOTE(review): ToPILImage treats float tensors as [0, 1] data; confirm the tensors saved
    # in the Case 1 file are not normalised, otherwise the previews will be distorted.
    batch_images, batch_labels, _image_names = next(iter(dataloader_c1))
    pil_images = list(map(transform_to_pil, batch_images))
    return pil_images, batch_labels.tolist()



In [8]:
# Case 2 samples. Going by the file name these are MNIST-M items misclassified by both the
# non-DANN and the DANN model - TODO confirm against the script that wrote the file.
# NOTE(review): torch.load unpickles the file, so only load it from a trusted location.
list_c2 = torch.load('/content/gdrive/MyDrive/da_demo/cv/models/26_06/list_mnist_m_non_dann_misclassified_dann_misclassified.pt')
dataset_c2 = CustomDataset(list_c2)
# Batches of 10, reshuffled each time a new iterator is created.
dataloader_c2 = torch.utils.data.DataLoader(dataset_c2, batch_size=10, shuffle=True)
def get_images_2():
    """Draw one shuffled batch of Case 2 samples.

    Returns:
        tuple: ``(pil_images, labels)`` - a list of PIL images and the matching labels as a
        plain Python list. The image names stored with each sample are not returned.
    """
    # NOTE(review): ToPILImage treats float tensors as [0, 1] data; confirm the tensors saved
    # in the Case 2 file are not normalised, otherwise the previews will be distorted.
    batch_images, batch_labels, _image_names = next(iter(dataloader_c2))
    pil_images = list(map(transform_to_pil, batch_images))
    return pil_images, batch_labels.tolist()

In [9]:
# next(iter(dataloader_c1))

In [10]:
def get_device():
    """Pick the torch device to run on, print the choice, and return its name.

    Preference order is CUDA, then Apple MPS, then CPU.

    Returns:
        str: one of ``"cuda"``, ``"mps"`` or ``"cpu"``.
    """
    # torch.backends.mps is not present on every PyTorch build; look it up defensively so
    # such installs fall through to the CPU instead of raising AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)

    if torch.cuda.is_available():
        device = "cuda"
    elif mps_backend is not None and mps_backend.is_available():
        device = "mps"
    else:
        device = "cpu"
    print("Device Selected:", device)
    return device

device = get_device()


Device Selected: cpu


In [11]:

class GradientReversalFn(Function):
    """Gradient reversal: identity on the forward pass, ``-alpha * grad`` on the backward pass."""

    @staticmethod
    def forward(ctx, x, alpha):
        # Remember the scaling factor for backward().
        ctx.alpha = alpha
        # The forward result is x itself, returned as a view.
        identity_view = x.view_as(x)
        return identity_view

    @staticmethod
    def backward(ctx, grad_output):
        # Flip the sign of the incoming gradient and scale it by alpha.
        reversed_grad = grad_output.neg() * ctx.alpha
        # Second entry is the gradient for `alpha`, which receives none.
        return reversed_grad, None

class Network(nn.Module):
    """Convolutional feature extractor with a class head and a domain head (DANN layout).

    The domain head receives the features through ``GradientReversalFn``, so gradients
    flowing back from it into the feature extractor are negated and scaled by ``alpha``.

    Args:
        num_classes: number of outputs of the class head (default 10).
    """

    def __init__(self, num_classes = 10):
        super(Network, self).__init__()  # Initialize the parent class

        drop_out_value = 0.1

        # NOTE: layer order (and therefore the Sequential indices) must stay as-is so that
        # saved state dictionaries keep matching the parameter names.

        #---------------------Feature Extractor Network------------------------#
        self.feature_extractor  = nn.Sequential(
            # Input Block
            nn.Conv2d(3, 16, 3, bias=False),  # In: 3x28x28, Out: 16x26x26, RF: 3x3, Stride: 1
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.Dropout(drop_out_value),

            # Conv Block 2
            nn.Conv2d(16, 16, 3, bias=False),  # In: 16x26x26, Out: 16x24x24, RF: 5x5, Stride: 1
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.Dropout(drop_out_value),

            # Conv Block 3
            nn.Conv2d(16, 16, 3, bias=False),  # In: 16x24x24, Out: 16x22x22, RF: 7x7, Stride: 1
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.Dropout(drop_out_value),

            # Transition Block 1
            nn.MaxPool2d(kernel_size=2, stride=2),  # In: 16x22x22, Out: 16x11x11, RF: 8x8, Stride: 2

            # Conv Block 4
            nn.Conv2d(16, 16, 3, bias=False),  # In: 16x11x11, Out: 16x9x9, RF: 12x12, Stride: 1
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.Dropout(drop_out_value),

            # Conv Block 5
            nn.Conv2d(16, 32, 3, bias=False),  # In: 16x9x9, Out: 32x7x7, RF: 16x16, Stride: 1
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Dropout(drop_out_value),

            # Output Block
            nn.Conv2d(32, 64, 1, bias=False),  # In: 32x7x7, Out: 64x7x7, RF: 16x16, Stride: 1

            # Global Average Pooling
            nn.AvgPool2d(7)  # In: 64x7x7, Out: 64x1x1, RF: 16x16, Stride: 7
        )

        #---------------------Class Classifier Network------------------------#
        self.class_classifier = nn.Sequential(nn.ReLU(),
                                        nn.Dropout(p=drop_out_value),
                                        nn.Linear(64,50),
                                        nn.BatchNorm1d(50), # added batch norm to improve accuracy
                                        nn.ReLU(),
                                        nn.Dropout(p=drop_out_value),
                                        nn.Linear(50,num_classes))

        #---------------------Domain Classifier Network------------------------#
        self.domain_classifier = nn.Sequential(nn.ReLU(),
                                        nn.Dropout(p=drop_out_value),
                                        nn.Linear(64,50),
                                        nn.BatchNorm1d(50), # added batch norm to improve accuracy
                                        nn.ReLU(),
                                        nn.Dropout(p=drop_out_value),
                                        nn.Linear(50,2))

    def forward(self, input_data, alpha = 1.0):
        """Run the network on a batch of images.

        Args:
            input_data: image batch with the channel on dimension 1. A single-channel batch
                is broadcast to three channels before the first convolution.
            alpha: scaling factor handed to the gradient reversal function.

        Returns:
            tuple: ``(class_output, domain_output, features)`` where ``features`` is the
            flattened output of the feature extractor.
        """
        if input_data.shape[1] == 1:
            # Broadcast grayscale to 3 channels. -1 keeps the batch and spatial sizes as they
            # are, so this no longer depends on a module-level `img_size` being defined.
            input_data = input_data.expand(-1, 3, -1, -1)

        input_data = self.feature_extractor(input_data)

        features = input_data.view(input_data.size(0), -1)  # Flatten the output for fully connected layer

        reverse_features = GradientReversalFn.apply(features, alpha)
        class_output = self.class_classifier(features)
        domain_output = self.domain_classifier(reverse_features)

        return class_output, domain_output, features

In [None]:
## NON DANN
# Instantiate the model (make sure it has the same architecture) and move it to the selected device.
loaded_model_non_dann = Network()
loaded_model_non_dann = loaded_model_non_dann.to(device)
# Load the saved state dictionary. strict=False tolerates key mismatches, so report any of them
# explicitly instead of silently running with layers that were never loaded.
non_dann_load_result = loaded_model_non_dann.load_state_dict(torch.load('/content/gdrive/MyDrive/da_demo/cv/models/26_06/non_dann_26_06.pt', map_location=device), strict=False)
if non_dann_load_result.missing_keys or non_dann_load_result.unexpected_keys:
    print("WARNING: non-DANN checkpoint does not fully match Network;",
          "missing keys:", non_dann_load_result.missing_keys,
          "unexpected keys:", non_dann_load_result.unexpected_keys)
# Switch to inference mode; as the last expression this also displays the model summary.
loaded_model_non_dann.eval()

In [None]:
##  DANN
# Instantiate the model (make sure it has the same architecture) and move it to the selected device.
loaded_model_dann = Network()
loaded_model_dann = loaded_model_dann.to(device)
# Load the saved state dictionary. strict=False tolerates key mismatches, so report any of them
# explicitly instead of silently running with layers that were never loaded.
dann_load_result = loaded_model_dann.load_state_dict(torch.load('/content/gdrive/MyDrive/da_demo/cv/models/26_06/dann_26_06.pt', map_location=device), strict=False)
if dann_load_result.missing_keys or dann_load_result.unexpected_keys:
    print("WARNING: DANN checkpoint does not fully match Network;",
          "missing keys:", dann_load_result.missing_keys,
          "unexpected keys:", dann_load_result.unexpected_keys)
# Switch to inference mode; as the last expression this also displays the model summary.
loaded_model_dann.eval()

In [14]:
# Image side length in pixels; read by the Resize transforms and by Network.forward when it
# expands single-channel inputs.
img_size = 28 # for mnist
# Batch size of the MNIST and MNIST-M example loaders defined below.
cpu_batch_size = 10
# Display names for the class logits; entry i labels output i in the confidence dictionaries.
class_names = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']

In [16]:
def classify_image_both(image):
    """Classify a digit image with both the non-DANN and the DANN model.

    Args:
        image: PIL image. It is converted to RGB and resized to ``img_size`` x ``img_size``
            before being normalised with mean 0.5 / std 0.5 per channel.

    Returns:
        tuple: ``(non_dann_confidences, dann_confidences)``. Each entry is an ``OrderedDict``
        mapping the three most likely class names to their softmax probability, most likely first.
    """
    target_test_transforms = transforms.Compose([
        # Resize to an exact square: an int size only fixes the shorter side, which lets
        # non-square uploads reach the network with a shape the 7x7 pooling cannot reduce to 1x1.
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),  # converts to tensor
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # The 3-channel Normalize above needs an RGB image; convert() is a plain copy for RGB input.
    rgb_image = image.convert("RGB")
    image_tensor = target_test_transforms(rgb_image).to(device).unsqueeze(0)

    list_confidences = []
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        for model in (loaded_model_non_dann, loaded_model_dann):
            model.eval()
            logits, _, _ = model(image_tensor)
            probabilities = F.softmax(logits.view(-1), dim=-1).tolist()

            # Pair each class name with its probability and keep the three highest.
            ranked = sorted(zip(class_names, probabilities), key=lambda pair: pair[1], reverse=True)
            list_confidences.append(OrderedDict(ranked[:3]))

    return list_confidences[0], list_confidences[1]

In [17]:
### SOURCE DATA - MNIST

# Test-phase transformation: only convert PIL images to tensors (no resize, no normalisation).
test_transforms = transforms.Compose([transforms.ToTensor()])
transform_to_pil = transforms.ToPILImage()

# MNIST test split, stored under ./data and downloaded when requested by download=True.
test = datasets.MNIST(
    './data',
    train=False,
    download=True,
    transform=test_transforms,
)

# Shuffled batches of `cpu_batch_size` images.
dataloader_args = {"shuffle": True, "batch_size": cpu_batch_size}
mnist_loader = torch.utils.data.DataLoader(dataset=test, **dataloader_args)

def get_mnist_images():
    """Draw one shuffled batch from the MNIST test loader.

    Returns:
        tuple: ``(pil_images, labels)`` - a list of PIL images and the matching labels as a
        plain Python list.
    """
    batch_images, batch_labels = next(iter(mnist_loader))
    pil_images = list(map(transform_to_pil, batch_images))
    return pil_images, batch_labels.tolist()

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 20363420.84it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 640509.37it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 5587267.19it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 3033523.69it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






In [18]:
# Parquet shard paths inside the Mike0307/MNIST-M dataset repository on the Hugging Face Hub.
splits = {'train': 'data/train-00000-of-00001-571b6b1e2c195186.parquet', 'test': 'data/test-00000-of-00001-ba3ad971b105ff65.parquet'}
# Only the test split is read; the 'train' entry above is not used in this cell.
df = pd.read_parquet("hf://datasets/Mike0307/MNIST-M/" + splits["test"])

class MNIST_M(torch.utils.data.Dataset):
    """Dataset over a DataFrame holding encoded images and their labels.

    Each row is expected to provide an ``image`` entry with ``bytes`` and ``path`` items and
    a ``label`` entry, as read by ``__getitem__``.
    """

    def __init__(self, dataframe, transform=None):
        self.dataframe = dataframe
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the underlying DataFrame."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Return ``(image, label, image_path)`` for the row at position ``idx``."""
        # Look the row up once and read every field from it.
        row = self.dataframe.iloc[idx]
        raw_bytes = row['image']['bytes']
        label = row['label']
        img_path = row['image']['path']

        # The stored bytes are an encoded image file; let PIL open them from memory.
        img = Image.open(io.BytesIO(raw_bytes))

        # Apply transformations if any
        if self.transform:
            img = self.transform(img)

        return img, label, img_path


# Test Phase transformations
target_test_transforms = transforms.Compose([
    transforms.Resize(img_size),
    transforms.ToTensor(),  # PIL image -> float tensor
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

transform_to_pil = transforms.ToPILImage()

# Create dataset
target_test_dataset = MNIST_M(dataframe=df, transform=target_test_transforms)
# Shuffled batches of `cpu_batch_size` MNIST-M test images.
target_test_dataloader = torch.utils.data.DataLoader(
    target_test_dataset,
    batch_size=cpu_batch_size,
    shuffle=True,
)
def get_mnist_m_images():
    """Draw one shuffled batch of MNIST-M test images as displayable PIL images.

    Returns:
        tuple: ``(pil_images, labels)`` - a list of PIL images and the matching labels as a
        plain Python list. The image paths yielded by the loader are not returned.
    """
    # target_test_transforms normalises every channel with mean 0.5 / std 0.5, so the batch
    # holds values in [-1, 1]. ToPILImage multiplies float tensors by 255 and casts to bytes,
    # which corrupts negative values, so the normalisation is undone first.
    norm_mean, norm_std = 0.5, 0.5

    images, labels, _image_names = next(iter(target_test_dataloader))
    restored = (images * norm_std + norm_mean).clamp(0.0, 1.0)
    # Round (rather than truncate) back to 8-bit so pixel values survive the round trip.
    images_uint8 = (restored * 255.0).round().to(torch.uint8)

    pil_images = [transform_to_pil(image) for image in images_uint8]
    return pil_images, labels.tolist()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [19]:
# Draw one random batch from each dataset; these are used as the clickable examples in the
# "Inferencing" tab of the demo below.
mnist_images, mnist_labels = get_mnist_images()
mnist_m_images,mnist_m_labels = get_mnist_m_images()

In [20]:
def classify_image_inference(image):
    """Classify a digit image with both the non-DANN and the DANN model.

    The preprocessing is chosen from the image mode: grayscale (``"L"``) images are treated
    as MNIST and normalised with mean 0.1307 / std 0.3081; every other mode is treated as
    MNIST-M, converted to RGB and normalised with mean 0.5 / std 0.5 per channel.

    Args:
        image: PIL image.

    Returns:
        tuple: ``(non_dann_confidences, dann_confidences)``. Each entry is an ``OrderedDict``
        mapping the three most likely class names to their softmax probability, most likely first.
    """
    if image.mode == "L":
        # MNIST: single channel; the network broadcasts it to three channels itself.
        normalize = transforms.Normalize((0.1307,), (0.3081,))
    else:
        # MNIST-M: the 3-channel Normalize needs an RGB image.
        if image.mode != "RGB":
            image = image.convert("RGB")
        normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

    image_transforms = transforms.Compose([
        # Resize to an exact square: an int size only fixes the shorter side, which lets
        # non-square uploads reach the network with a shape the 7x7 pooling cannot reduce to 1x1.
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),  # converts to tensor
        normalize,
    ])

    image_tensor = image_transforms(image).to(device).unsqueeze(0)

    list_confidences = []
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        for model in (loaded_model_non_dann, loaded_model_dann):
            model.eval()
            logits, _, _ = model(image_tensor)
            probabilities = F.softmax(logits.view(-1), dim=-1).tolist()

            # Pair each class name with its probability and keep the three highest.
            ranked = sorted(zip(class_names, probabilities), key=lambda pair: pair[1], reverse=True)
            list_confidences.append(OrderedDict(ranked[:3]))

    return list_confidences[0], list_confidences[1]



In [21]:
def display_image():
    """Open the illustration stored on Google Drive and return it as a PIL image."""
    image_path = "/content/gdrive/MyDrive/da_demo/mnist-m.JPG"
    return Image.open(image_path)

In [None]:
with gr.Blocks() as demo:
  # Landing tab: title, illustration and a short description of the two datasets and the
  # remaining tabs.
  with gr.Tab("Introduction"):
      gr.Markdown("## Domain Adaptation in Deep Networks - Demonstration")
      with gr.Row():
          with gr.Column():
              image_output = gr.Image(value=display_image(), label = "source and target",height = 256, width = 256, show_label = True)
      gr.Markdown(
          '''
          Source - MNIST
          ------
          - The MNIST database (Modified National Institute of Standards and Technology database) is a large collection of handwritten digits.
          - It has a training set of 60,000 examples, and a test set of 10,000 examples.
          - 28 x 28 size
          - 1 channel

          '''
          )
      # Test-set size corrected from "90,001" to 9,001.
      gr.Markdown(
          '''
          Target - MNIST-M
          -------
          - MNIST-M is created by combining MNIST digits with the patches randomly extracted from color photos of BSDS500 as their background.
          - It contains 59,001 training and 9,001 test images.
          - 28 x 28 size
          - 3 channels
          '''
      )

      # Case 1 is the list the non-DANN model gets wrong and the DANN model gets right
      # (see the tab title), so the description must end with "DANN", not "NonDANN".
      gr.Markdown(
          '''
          Please click on the tabs, for more functionality
          -------
          - Inferencing on NonDANN and DANN : Infer MNIST or MNISTM on both Models
          - Case 1: MNIST_M_Non_DANN_Misclassify_DANN_Classify : Curated list which misclassify on NON DANN but classifies well on DANN
          - Case 2: MNIST_M_Both_Misclassify : Curated list which misclassify Both on NON DANN and DANN
          '''
      )



################################################
  # Inference tab: one row for MNIST (grayscale) and one for MNIST-M (RGB). Each row has an
  # image input with a submit button on the left and the two models' predictions on the right.
  with gr.Tab("Inferencing on NonDANN and DANN"):
    with gr.Row():
      with gr.Column():
        # image_mode 'L' matters: classify_image_inference picks the MNIST preprocessing
        # when image.mode == "L".
        input_image_classify_mnist = gr.Image(label="Classify MNIST Digit", type = "pil", height = 256, width = 256, image_mode = 'L')
        button_classify_mnist = gr.Button("Submit to Classify MNIST Image", visible = True, size ='sm')
      with gr.Column():
        # The handler returns the top-3 classes; the labels are configured with num_top_classes=2.
        with gr.Row():
          label_classify_mnist_non_dann = gr.Label(label = "NON DANN Predicted MNIST label", num_top_classes=2, visible = True)
        with gr.Row():
          label_classify_mnist_dann = gr.Label(label = "DANN Predicted MNIST label", num_top_classes=2, visible = True)
    with gr.Row():
      # Clickable examples taken from the MNIST batch drawn earlier.
      gr.Examples( [img.convert("L") for img in mnist_images],
                  inputs=[input_image_classify_mnist], label = "Select an example MNIST Image")

    with gr.Row():
      with gr.Column():
        input_image_classify_mnist_m = gr.Image(label="Classify MNIST M Digit", type = "pil", height = 256, width = 256, image_mode = 'RGB')
        button_classify_mnist_m = gr.Button("Submit to Classify MNIST M Image", visible = True, size ='sm')
      with gr.Column():
        with gr.Row():
          label_classify_mnist_m_non_dann = gr.Label(label = "NON DANN Predicted MNIST M label", num_top_classes=2, visible = True)
        with gr.Row():
          label_classify_mnist_m_dann = gr.Label(label = "DANN Predicted MNIST M label", num_top_classes=2, visible = True)
    with gr.Row():
      # Clickable examples taken from the MNIST-M batch drawn earlier.
      gr.Examples( [img.convert("RGB") for img in mnist_m_images],
                  inputs=[input_image_classify_mnist_m], label = "Select an example MNIST M Image")
    with gr.Row():
      # Ground-truth labels of the MNIST-M examples, in the same order as the examples.
      # (The MNIST labels in `mnist_labels` are not displayed.)
      gr.Markdown(value = f'MNIST- M Ground Truth Label = {[label for label in mnist_m_labels]}')

    # Both buttons share the same handler; it tells the datasets apart by image mode.
    button_classify_mnist.click(fn=classify_image_inference,
                          inputs=[input_image_classify_mnist],
                          outputs=[label_classify_mnist_non_dann, label_classify_mnist_dann])

    button_classify_mnist_m.click(fn=classify_image_inference,
                          inputs=[input_image_classify_mnist_m],
                          outputs=[label_classify_mnist_m_non_dann, label_classify_mnist_m_dann])


  ######################
  # Case 1 tab: examples come from the Case 1 loader (get_images) and every submission is
  # classified by both models via classify_image_both.
  with gr.Tab("Case 1: MNIST_M_Non_DANN_Misclassify_DANN_Classify"):
    with gr.Row():
      with gr.Column():
        input_image_classify_both = gr.Image(label="Classify Digit", type = "pil", height = 256, width = 256)
        button_classify_both = gr.Button("Submit to Classify Image with Both Models", visible = True, size ='sm')

      with gr.Column():
        # The handler returns the top-3 classes; the labels are configured with num_top_classes=2.
        with gr.Row():
          label_classify_non_dann = gr.Label(label = "NON DANN Predicted label", num_top_classes=2, visible = True)
        with gr.Row():
          label_classify_dann = gr.Label(label = "DANN Predicted label", num_top_classes=2, visible = True)

    # One random batch, drawn once while the UI is being built.
    mnist_m_images_1,mnist_m_labels_1 = get_images()

    with gr.Row():
      gr.Examples(mnist_m_images_1,inputs=[input_image_classify_both], label = "Select an example MNIST-M Image") #working

    with gr.Row():
      # Ground-truth labels, in the same order as the examples above.
      gr.Markdown(value = f'MNIST- M Ground Truth Label = {[label for label in mnist_m_labels_1]}')

    button_classify_both.click(fn=classify_image_both,
                          inputs=[input_image_classify_both],
                          outputs=[label_classify_non_dann,label_classify_dann])


########################################################################

  # Case 2 tab: same layout as Case 1, with examples from the Case 2 loader (get_images_2).
  # NOTE: the component variable names below are the ones already used in the Case 1 tab and
  # are rebound here; the Case 1 click handler was registered before this point, so it keeps
  # pointing at the Case 1 components.
  with gr.Tab("Case 2 - Show both: MNIST_M_Both_Misclassify"):

    with gr.Row():
      with gr.Column():
        input_image_classify_both = gr.Image(label="Classify Digit", type = "pil", height = 256, width = 256)
        button_classify_both = gr.Button("Submit to Classify Image with Both Models", visible = True, size ='sm')

      with gr.Column():
        # The handler returns the top-3 classes; the labels are configured with num_top_classes=2.
        with gr.Row():
          label_classify_non_dann = gr.Label(label = "NON DANN Predicted label", num_top_classes=2, visible = True)
        with gr.Row():
          label_classify_dann = gr.Label(label = "DANN Predicted label", num_top_classes=2, visible = True)

    # One random batch, drawn once while the UI is being built.
    mnist_m_images_2,mnist_m_labels_2 = get_images_2()

    with gr.Row():
      gr.Examples(mnist_m_images_2,inputs=[input_image_classify_both], label = "Select an example MNIST-M Image") #working

    with gr.Row():
      # Ground-truth labels, in the same order as the examples above.
      gr.Markdown(value = f'MNIST- M Ground Truth Label = {[label for label in mnist_m_labels_2]}')


    button_classify_both.click(fn=classify_image_both,
                          inputs=[input_image_classify_both],
                          outputs=[label_classify_non_dann,label_classify_dann])


# Start the demo. With debug=True the cell keeps running in Colab so errors and logs stay
# visible; set debug=False to return control to the notebook.
demo.launch(debug=True)

Setting queue=True in a Colab notebook requires sharing enabled. Setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
Running on public URL: https://fdba0c0185237471e3.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)
