In [1]:
# Mount Google Drive at /content/drive so files stored there are reachable
# from this Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
cd /content/drive/MyDrive/Dissertation/

/content/drive/MyDrive/Dissertation


In [None]:
!pip install -r requirements.txt

In [16]:
# Google Colab Unit Tests for Multimodal ViT Models
# Run this in a Colab cell after importing your model classes

# Standard library
import traceback
import unittest
import warnings
from io import StringIO

# Third-party
import torch

# Suppress warnings for cleaner output
warnings.filterwarnings('ignore')

class TestViTForFER(unittest.TestCase):
    """Checks construction, forward pass and helper methods of ViTForFER."""

    def setUp(self):
        """Create a random two-image batch and matching labels for every test."""
        self.batch_size = 2
        self.num_classes = 7
        image_shape = (self.batch_size, 3, 224, 224)
        self.pixel_values = torch.randn(*image_shape)
        self.labels = torch.randint(0, self.num_classes, (self.batch_size,))

    def test_initialization_default(self):
        """A model built without arguments exposes the expected defaults."""
        fer_model = ViTForFER()
        self.assertEqual(fer_model.num_classes, 7)
        self.assertEqual(fer_model.dropout_rate, 0.1)
        self.assertFalse(fer_model.freeze_backbone)

    def test_initialization_custom(self):
        """Constructor overrides are stored on the model unchanged."""
        overrides = {
            "num_classes": 5,
            "dropout_rate": 0.2,
            "freeze_backbone": True,
        }
        fer_model = ViTForFER(**overrides)
        self.assertEqual(fer_model.num_classes, 5)
        self.assertEqual(fer_model.dropout_rate, 0.2)
        self.assertTrue(fer_model.freeze_backbone)

    def test_config_setup(self):
        """The config maps seven ids to labels, including happy and angry."""
        fer_model = ViTForFER()
        self.assertEqual(len(fer_model.config.id2label), 7)
        for emotion in ("happy", "angry"):
            self.assertIn(emotion, fer_model.config.id2label.values())

    def test_forward_pass_with_labels(self):
        """Passing labels yields logits and a non-negative tensor loss."""
        fer_model = ViTForFER()
        outputs = fer_model(self.pixel_values, self.labels)

        self.assertIn('logits', outputs)
        loss = outputs['loss']
        self.assertIsNotNone(loss)
        self.assertIsInstance(loss, torch.Tensor)
        self.assertGreaterEqual(loss.item(), 0)

    def test_feature_extraction(self):
        """get_features returns one hidden_size-wide vector per image."""
        fer_model = ViTForFER()
        features = fer_model.get_features(self.pixel_values)

        self.assertEqual(
            features.shape,
            (self.batch_size, fer_model.config.hidden_size),
        )

    def test_attention_weights(self):
        """get_attention_weights returns a 4-D result led by the batch axis."""
        fer_model = ViTForFER()
        attention = fer_model.get_attention_weights(self.pixel_values)

        self.assertIsNotNone(attention)
        attention_shape = attention.shape
        self.assertEqual(len(attention_shape), 4)
        self.assertEqual(attention_shape[0], self.batch_size)

    def test_freeze_unfreeze(self):
        """The backbone starts frozen on request and can be unfrozen later."""
        fer_model = ViTForFER(freeze_backbone=True)

        # While frozen, no backbone parameter may require gradients.
        any_trainable = any(p.requires_grad for p in fer_model.vit.vit.parameters())
        self.assertFalse(any_trainable)

        # After unfreezing, every backbone parameter must require gradients.
        fer_model.unfreeze_backbone()
        any_frozen = any(not p.requires_grad for p in fer_model.vit.vit.parameters())
        self.assertFalse(any_frozen)
        self.assertFalse(fer_model.freeze_backbone)

    def test_model_info(self):
        """get_model_info reports every expected key with consistent values."""
        info = ViTForFER(num_classes=5, dropout_rate=0.15).get_model_info()

        expected_keys = (
            'model_name', 'num_classes', 'total_parameters', 'trainable_parameters',
            'freeze_backbone', 'dropout_rate', 'image_size', 'patch_size',
            'hidden_size', 'num_attention_heads', 'num_hidden_layers',
        )
        for key in expected_keys:
            self.assertIn(key, info)

        self.assertEqual(info['num_classes'], 5)
        self.assertEqual(info['dropout_rate'], 0.15)
        parameter_total = info['total_parameters']
        self.assertGreater(parameter_total, 0)
        self.assertLessEqual(info['trainable_parameters'], parameter_total)


class TestEarlyFusionViT(unittest.TestCase):
    """Checks construction, forward pass and freezing of EarlyFusionViT."""

    def setUp(self):
        """Create random RGB and thermal batches plus labels for every test."""
        self.batch_size = 2
        image_shape = (self.batch_size, 3, 224, 224)
        self.rgb_images = torch.randn(*image_shape)
        self.thermal_images = torch.randn(*image_shape)
        self.labels = torch.randint(0, 7, (self.batch_size,))

    def test_initialization_concat(self):
        """Concat fusion is recorded and reports six input channels."""
        fusion_model = EarlyFusionViT(fusion_type="concat")
        self.assertEqual(fusion_model.input_channels, 6)
        self.assertEqual(fusion_model.fusion_type, "concat")

    def test_initialization_add(self):
        """Add fusion is recorded and reports three input channels."""
        fusion_model = EarlyFusionViT(fusion_type="add")
        self.assertEqual(fusion_model.input_channels, 3)
        self.assertEqual(fusion_model.fusion_type, "add")

    def test_forward_pass_concat(self):
        """Concat fusion produces (batch, 7) logits and a loss."""
        fusion_model = EarlyFusionViT(fusion_type="concat")
        outputs = fusion_model(self.rgb_images, self.thermal_images, self.labels)

        self.assertIn('logits', outputs)
        expected_shape = (self.batch_size, 7)
        self.assertEqual(outputs['logits'].shape, expected_shape)
        self.assertIsNotNone(outputs['loss'])

    def test_forward_pass_add(self):
        """Add fusion produces (batch, 7) logits and a loss."""
        fusion_model = EarlyFusionViT(fusion_type="add")
        outputs = fusion_model(self.rgb_images, self.thermal_images, self.labels)

        self.assertIn('logits', outputs)
        expected_shape = (self.batch_size, 7)
        self.assertEqual(outputs['logits'].shape, expected_shape)
        self.assertIsNotNone(outputs['loss'])

    def test_freeze_functionality(self):
        """The backbone starts frozen on request and can be unfrozen later."""
        fusion_model = EarlyFusionViT(freeze_backbone=True)

        # While frozen, no backbone parameter may require gradients.
        any_trainable = any(p.requires_grad for p in fusion_model.vit.parameters())
        self.assertFalse(any_trainable)

        # After unfreezing, every backbone parameter must require gradients.
        fusion_model.unfreeze_backbone()
        any_frozen = any(not p.requires_grad for p in fusion_model.vit.parameters())
        self.assertFalse(any_frozen)


class TestLateFusionViT(unittest.TestCase):
    """Checks construction and forward pass of LateFusionViT variants."""

    def setUp(self):
        """Create random RGB and thermal batches plus labels for every test."""
        self.batch_size = 2
        image_shape = (self.batch_size, 3, 224, 224)
        self.rgb_images = torch.randn(*image_shape)
        self.thermal_images = torch.randn(*image_shape)
        self.labels = torch.randint(0, 7, (self.batch_size,))

    def test_initialization_feature_fusion(self):
        """Feature-level fusion exposes a classifier and both ViT branches."""
        late_model = LateFusionViT(fusion_type="concat", fusion_layer="feature")
        for attribute in ('classifier', 'rgb_vit', 'thermal_vit'):
            self.assertTrue(hasattr(late_model, attribute))

    def test_initialization_prediction_fusion(self):
        """Prediction-level fusion exposes one classifier per modality."""
        late_model = LateFusionViT(fusion_type="concat", fusion_layer="prediction")
        for attribute in ('rgb_classifier', 'thermal_classifier'):
            self.assertTrue(hasattr(late_model, attribute))

    def test_initialization_attention_fusion(self):
        """Attention fusion exposes an attention_fusion attribute."""
        late_model = LateFusionViT(fusion_type="attention", fusion_layer="feature")
        self.assertTrue(hasattr(late_model, 'attention_fusion'))

    def test_forward_pass_feature_level(self):
        """Each feature-level fusion type returns logits, loss and features."""
        fusion_configs = [
            {"fusion_type": fusion_type, "fusion_layer": "feature"}
            for fusion_type in ("concat", "add", "attention")
        ]
        expected_shape = (self.batch_size, 7)

        for config in fusion_configs:
            with self.subTest(config=config):
                late_model = LateFusionViT(**config)
                outputs = late_model(self.rgb_images, self.thermal_images, self.labels)

                self.assertIn('logits', outputs)
                self.assertEqual(outputs['logits'].shape, expected_shape)
                self.assertIsNotNone(outputs['loss'])
                self.assertIn('rgb_features', outputs)
                self.assertIn('thermal_features', outputs)

    def test_forward_pass_prediction_level(self):
        """Each prediction-level fusion type returns logits and a loss."""
        fusion_configs = [
            {"fusion_type": fusion_type, "fusion_layer": "prediction"}
            for fusion_type in ("concat", "add", "attention")
        ]
        expected_shape = (self.batch_size, 7)

        for config in fusion_configs:
            with self.subTest(config=config):
                late_model = LateFusionViT(**config)
                outputs = late_model(self.rgb_images, self.thermal_images, self.labels)

                self.assertIn('logits', outputs)
                self.assertEqual(outputs['logits'].shape, expected_shape)
                self.assertIsNotNone(outputs['loss'])


class TestModelFactory(unittest.TestCase):
    """Checks which model class create_multimodal_vit_model returns per mode."""

    def test_create_rgb_model(self):
        """mode='rgb' yields a ViTForFER instance."""
        rgb_model = create_multimodal_vit_model(mode='rgb')
        self.assertIsInstance(rgb_model, ViTForFER)

    def test_create_thermal_model(self):
        """mode='thermal' yields a ViTForFER instance."""
        thermal_model = create_multimodal_vit_model(mode='thermal')
        self.assertIsInstance(thermal_model, ViTForFER)

    def test_create_early_fusion_model(self):
        """Combined mode with the early strategy yields an EarlyFusionViT."""
        factory_kwargs = {
            "mode": 'combined',
            "fusion_strategy": 'early',
            "fusion_type": 'concat',
        }
        fused_model = create_multimodal_vit_model(**factory_kwargs)
        self.assertIsInstance(fused_model, EarlyFusionViT)

    def test_create_late_fusion_model(self):
        """Combined mode with the late strategy yields a LateFusionViT."""
        factory_kwargs = {
            "mode": 'combined',
            "fusion_strategy": 'late',
            "fusion_type": 'concat',
            "fusion_layer": 'feature',
        }
        fused_model = create_multimodal_vit_model(**factory_kwargs)
        self.assertIsInstance(fused_model, LateFusionViT)

    def test_invalid_mode(self):
        """An unrecognised mode raises ValueError."""
        self.assertRaises(ValueError, create_multimodal_vit_model, mode='invalid')


class TestOptimizerAndScheduler(unittest.TestCase):
    """Checks the optimizer returned by get_optimizer_and_scheduler."""

    def test_adamw_optimizer(self):
        """optimizer_type='adamw' yields torch AdamW with two param groups."""
        fer_model = ViTForFER()
        # The scheduler half of the pair is unpacked but not inspected here.
        optimizer, _scheduler = get_optimizer_and_scheduler(
            model=fer_model,
            learning_rate=1e-4,
            optimizer_type="adamw",
        )

        self.assertIsInstance(optimizer, torch.optim.AdamW)
        group_count = len(optimizer.param_groups)
        self.assertEqual(group_count, 2)

    def test_sgd_optimizer(self):
        """optimizer_type='sgd' yields torch SGD with two param groups."""
        fer_model = ViTForFER()
        # The scheduler half of the pair is unpacked but not inspected here.
        optimizer, _scheduler = get_optimizer_and_scheduler(
            model=fer_model,
            optimizer_type="sgd",
        )

        self.assertIsInstance(optimizer, torch.optim.SGD)
        group_count = len(optimizer.param_groups)
        self.assertEqual(group_count, 2)


class TestGradientFlow(unittest.TestCase):
    """Unit tests for gradient flow through ViTForFER after a backward pass."""

    def test_gradient_flow_unfrozen(self):
        """At least one trainable parameter receives a gradient when unfrozen."""
        model = ViTForFER()
        batch_size = 2
        pixel_values = torch.randn(batch_size, 3, 224, 224, requires_grad=True)
        labels = torch.randint(0, 7, (batch_size,))

        outputs = model(pixel_values, labels)
        loss = outputs['loss']
        loss.backward()

        # Check gradients exist on some trainable parameter
        has_gradients = any(
            param.grad is not None
            for param in model.parameters()
            if param.requires_grad
        )

        self.assertTrue(has_gradients)

    def test_gradient_flow_frozen(self):
        """With a frozen backbone only classifier parameters hold gradients."""
        model = ViTForFER(freeze_backbone=True)
        batch_size = 2
        pixel_values = torch.randn(batch_size, 3, 224, 224)
        labels = torch.randint(0, 7, (batch_size,))

        outputs = model(pixel_values, labels)
        loss = outputs['loss']
        loss.backward()

        # Check only classifier has gradients
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.assertIsNotNone(param.grad)
                self.assertIn('classifier', name)
            else:
                # Autograd does not populate .grad for parameters that do not
                # require gradients, so a frozen parameter must stay empty.
                self.assertIsNone(
                    param.grad,
                    msg=f"frozen parameter {name} received a gradient",
                )


class TestInputValidation(unittest.TestCase):
    """Checks model outputs across batch sizes and early-fusion types."""

    def test_different_batch_sizes(self):
        """The leading logits dimension follows the input batch size."""
        fer_model = ViTForFER()

        for batch_size in (1, 2, 4, 8):
            with self.subTest(batch_size=batch_size):
                batch = torch.randn(batch_size, 3, 224, 224)
                logits = fer_model(batch)['logits']
                self.assertEqual(logits.shape[0], batch_size)

    def test_multimodal_consistency(self):
        """Concat and add fusion give same-shaped but different logits."""
        batch_size = 2
        image_shape = (batch_size, 3, 224, 224)
        rgb_images = torch.randn(*image_shape)
        thermal_images = torch.randn(*image_shape)
        labels = torch.randint(0, 7, (batch_size,))

        concat_model = EarlyFusionViT(fusion_type="concat")
        add_model = EarlyFusionViT(fusion_type="add")

        concat_logits = concat_model(rgb_images, thermal_images, labels)['logits']
        add_logits = add_model(rgb_images, thermal_images, labels)['logits']

        # The two fusion variants must not produce matching values...
        logits_match = torch.allclose(concat_logits, add_logits, atol=1e-6)
        self.assertFalse(logits_match)

        # ...while still agreeing on the output shape.
        self.assertEqual(concat_logits.shape, add_logits.shape)

In [7]:
# Colab-friendly test runner
def run_tests_in_colab(test_classes=None):
    """
    Run unit-test suites in Google Colab and print a formatted summary.

    Args:
        test_classes: Optional iterable of ``unittest.TestCase`` subclasses to
            run. When ``None`` (the default) every suite defined in this
            notebook is run.

    Returns:
        bool: ``True`` when no test failed or errored, ``False`` otherwise.
    """
    if test_classes is None:
        test_classes = [
            TestViTForFER,
            TestEarlyFusionViT,
            TestLateFusionViT,
            TestModelFactory,
            TestOptimizerAndScheduler,
            TestGradientFlow,
            TestInputValidation
        ]

    print("Starting Multimodal ViT Unit Tests")
    print("="*60)

    # The runner's verbose stream is captured to keep the cell output short;
    # failures and errors are re-reported below together with their tracebacks.
    test_output = StringIO()
    runner = unittest.TextTestRunner(stream=test_output, verbosity=2)
    loader = unittest.TestLoader()

    total_tests = 0
    total_failures = 0
    total_errors = 0
    total_skipped = 0

    for test_class in test_classes:
        print(f"\nRunning {test_class.__name__}...")
        suite = loader.loadTestsFromTestCase(test_class)
        result = runner.run(suite)

        total_tests += result.testsRun
        total_failures += len(result.failures)
        total_errors += len(result.errors)
        total_skipped += len(result.skipped)

        if result.failures:
            print(f"{len(result.failures)} failures")
            for test, details in result.failures:
                print(f"   FAIL: {test}")
                # Show why it failed; the captured runner stream is not displayed.
                print(details)

        if result.errors:
            print(f"💥 {len(result.errors)} errors")
            for test, details in result.errors:
                print(f"   ERROR: {test}")
                print(details)

        if not result.failures and not result.errors:
            print(" All tests passed!")

    # Skipped tests ran no assertions, so they are not counted as passed.
    total_passed = total_tests - total_failures - total_errors - total_skipped

    print("\n" + "="*60)
    print(" FINAL RESULTS:")
    print(f"   Total Tests: {total_tests}")
    print(f"   Passed: {total_passed}")
    print(f"   Failed: {total_failures}")
    print(f"   Errors: {total_errors}")
    if total_skipped:
        print(f"   Skipped: {total_skipped}")

    all_ok = total_failures == 0 and total_errors == 0
    if all_ok:
        print(" ==== ALL TESTS PASSED! ====")
    else:
        print(f"⚠️  {total_failures + total_errors} TESTS FAILED")

    return all_ok

In [8]:
# Quick test function for individual testing
def quick_test():
    """
    Run a quick smoke test of single-modality and early-fusion models.

    Builds a ViTForFER and an early-fusion multimodal model, pushes one random
    image (or image pair) through each, and prints progress along the way.

    Returns:
        bool: ``True`` when every step completed, ``False`` when any step
        raised an exception (the exception is reported, not re-raised).
    """
    print("🔥 Running Quick Smoke Test...")

    try:
        # Test basic model creation
        model = ViTForFER()
        print("✅ ViTForFER created successfully")

        # Test basic forward pass
        x = torch.randn(1, 3, 224, 224)
        outputs = model(x)
        print(f"✅ Forward pass successful, output shape: {outputs['logits'].shape}")

        # Test multimodal model
        multimodal_model = create_multimodal_vit_model(mode='combined', fusion_strategy='early')
        print("✅ Multimodal model created successfully")

        rgb = torch.randn(1, 3, 224, 224)
        thermal = torch.randn(1, 3, 224, 224)
        multimodal_outputs = multimodal_model(rgb, thermal)
        print(f"✅ Multimodal forward pass successful, output shape: {multimodal_outputs['logits'].shape}")

        print("🎉 Quick test completed successfully!")
        return True

    except Exception as e:
        # Deliberate catch-all: this is the top-level boundary of an interactive
        # smoke test, so report the failure instead of raising. The traceback is
        # printed as well because str(e) alone drops the exception type and the
        # failing line.
        print(f"❌ Quick test failed: {str(e)}")
        traceback.print_exc()
        return False

In [9]:
# Usage instructions for Colab
if __name__ == "__main__":
    # Usage reminder printed once before the suite starts.
    usage_text = """
    📚 How to use these tests in Google Colab:

    1. First, make sure your model classes are imported:
       from model import create_multimodal_vit_model, get_optimizer_and_scheduler, ViTForFER, EarlyFusionViT, LateFusionViT

    2. Run a quick test:
       quick_test()

    3. Run all unit tests:
       run_tests_in_colab()

    4. Run specific test class:
       suite = unittest.TestLoader().loadTestsFromTestCase(TestViTForFER)
       runner = unittest.TextTestRunner(verbosity=2)
       runner.run(suite)
    """
    print(usage_text)

    # The full suite runs automatically here; call quick_test() instead for a
    # faster smoke check.
    run_tests_in_colab()


    📚 How to use these tests in Google Colab:

    1. First, make sure your model classes are imported:
       from model import create_multimodal_vit_model, get_optimizer_and_scheduler, ViTForFER, EarlyFusionViT, LateFusionViT

    2. Run a quick test:
       quick_test()

    3. Run all unit tests:
       run_tests_in_colab()

    4. Run specific test class:
       suite = unittest.TestLoader().loadTestsFromTestCase(TestViTForFER)
       runner = unittest.TextTestRunner(verbosity=2)
       runner.run(suite)
    
Starting Multimodal ViT Unit Tests

Running TestViTForFER...


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Fast image processor class <class 'transformers.models.vit.image_processing_vit_fast.ViTImageProcessorFast'> is available for this model. Using slow image processor class. To use the fast image processor class set `use_fast=True`.
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Fast image processor class <class 'transformers.models.vit.image_processing_vit_fast.ViTImageProcessorFast'> is available for this model. Using slow image proc

 All tests passed!

Running TestEarlyFusionViT...


Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized because the shapes did not match:
- embeddings.patch_embeddings.projection.weight: found shape torch.Size([768, 3, 16, 16]) in the checkpoint and torch.Size([768, 6, 16, 16]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized because the shapes did not match:
- embeddings.patch_embeddings.projection.weight: found shape torch.Size([768, 3, 16, 16]) in the checkpoint and torch.Size([768, 6, 16, 16]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224-in2

 All tests passed!

Running TestLateFusionViT...
 All tests passed!

Running TestModelFactory...


Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized because the shapes did not match:
- embeddings.patch_embeddings.projection.weight: found shape torch.Size([768, 3, 16, 16]) in the checkpoint and torch.Size([768, 6, 16, 16]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Fast image processor class <class 'transformers.models.vit.image_processing_vit_fast.ViTImageProcessorFast'> is available for this model. Using slow image processor class. To use the fast image processor class set `use_fast=True`.
Some weights of V

 All tests passed!

Running TestOptimizerAndScheduler...


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Fast image processor class <class 'transformers.models.vit.image_processing_vit_fast.ViTImageProcessorFast'> is available for this model. Using slow image processor class. To use the fast image processor class set `use_fast=True`.
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Fast image processor class <class 'transformers.models.vit.image_processing_vit_fast.ViTImageProcessorFast'> is available for this model. Using slow image proc

 All tests passed!

Running TestGradientFlow...


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Fast image processor class <class 'transformers.models.vit.image_processing_vit_fast.ViTImageProcessorFast'> is available for this model. Using slow image processor class. To use the fast image processor class set `use_fast=True`.
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Fast image processor class <class 'transformers.models.vit.image_processing_vit_fast.ViTImageProcessorFast'> is available for this model. Using slow image proc

 All tests passed!

Running TestInputValidation...


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Fast image processor class <class 'transformers.models.vit.image_processing_vit_fast.ViTImageProcessorFast'> is available for this model. Using slow image processor class. To use the fast image processor class set `use_fast=True`.
Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized because the shapes did not match:
- embeddings.patch_embeddings.projection.weight: found shape torch.Size([768, 3, 16, 16]) in the checkpoint and torch.Size([768, 6, 16, 16]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


 All tests passed!

 FINAL RESULTS:
   Total Tests: 29
   Passed: 29
   Failed: 0
   Errors: 0
 ==== ALL TESTS PASSED! ====
