In [None]:
import json
from PIL import Image

import torch
from torchvision import transforms

from efficientnet_pytorch import EfficientNet

import torch
from torch import nn
from torch.nn import functional as F
class SwishImplementation(torch.autograd.Function):
    """Swish activation ``x * sigmoid(x)`` with a hand-written backward pass.

    Only the input tensor is saved for backward; the sigmoid is recomputed
    there rather than being stored, which is what makes this variant
    memory-efficient compared to letting autograd record the intermediate.
    """

    @staticmethod
    def forward(ctx, i):
        # Stash the raw input; everything else is rebuilt in backward().
        ctx.save_for_backward(i)
        return torch.sigmoid(i) * i

    @staticmethod
    def backward(ctx, grad_output):
        (saved_input,) = ctx.saved_tensors
        sig = torch.sigmoid(saved_input)
        # d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
        local_grad = sig * (1 + saved_input * (1 - sig))
        return grad_output * local_grad


class MemoryEfficientSwish(nn.Module):
    """``nn.Module`` wrapper that applies ``SwishImplementation`` to its input."""

    def forward(self, x):
        swish_fn = SwishImplementation.apply
        return swish_fn(x)
    
class Headmodel(nn.Module):
    """Feature head built from the stem and the leading blocks of an EfficientNet.

    The stem convolution, its batch norm and the first ``num_blocks`` entries
    of ``_blocks`` are taken from ``student_efficientnet`` by reference (the
    same module objects, so their parameters are shared with the source net).
    """

    def __init__(self, student_efficientnet=None, num_blocks=2):
        """
        Args:
            student_efficientnet: EfficientNet-like object exposing
                ``_conv_stem``, ``_bn0`` and ``_blocks``. When None, a
                pretrained ``efficientnet-b0`` is loaded for this instance.
                Loading here instead of in the default argument avoids fetching
                weights at class-definition time and keeps separate Headmodel
                instances from silently sharing one backbone.
            num_blocks (int): Number of leading blocks to keep (default 2).
        """
        super().__init__()
        if student_efficientnet is None:
            student_efficientnet = EfficientNet.from_pretrained('efficientnet-b0')
        self._conv_stem = student_efficientnet._conv_stem
        self._bn0 = student_efficientnet._bn0
        self._swish = MemoryEfficientSwish()
        self._blocks = student_efficientnet._blocks[:num_blocks]

    def set_swish(self, memory_efficient=True):
        """Sets swish function as memory efficient (for training) or standard (for export).

        Args:
            memory_efficient (bool): Whether to use memory-efficient version of swish.
        """
        if memory_efficient:
            self._swish = MemoryEfficientSwish()
        else:
            # No standard Swish is defined in this module; use the plain
            # x * sigmoid(x) module shipped with efficientnet_pytorch.
            from efficientnet_pytorch.utils import Swish
            self._swish = Swish()
        # Propagate the choice to every retained block.
        for block in self._blocks:
            block.set_swish(memory_efficient)

    def extract_features(self, inputs, drop_connect_rate=0.2):
        """Run the stem and the retained blocks on ``inputs``.

        Args:
            inputs (tensor): Input batch fed to the stem convolution.
            drop_connect_rate (float): Base drop-connect rate. Block ``idx``
                receives ``drop_connect_rate * idx / len(self._blocks)``, so the
                rate grows linearly with depth. A falsy value (0 / None) is
                passed to every block unchanged.

        Returns:
            Output of the last retained block.
        """
        # Stem: conv -> batch norm -> swish.
        x = self._swish(self._bn0(self._conv_stem(inputs)))

        # NOTE(review): the scaling denominator is the number of *retained*
        # blocks, not the full network's block count — confirm this is intended.
        total = len(self._blocks)
        for idx, block in enumerate(self._blocks):
            rate = drop_connect_rate
            if rate:
                rate *= float(idx) / total
            x = block(x, drop_connect_rate=rate)
        return x

    def forward(self, inputs):
        """Make the module callable: ``net(x)`` equals ``net.extract_features(x)``."""
        return self.extract_features(inputs)

# Build the head with its default backbone and show the resulting module tree.
# Example usage: head_output = net.extract_features(student_img)
net = Headmodel()
print(repr(net))