<a href="https://colab.research.google.com/github/thuanha-groove/Datasets/blob/master/clip_music_video_CPU.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch==1.7.1 torchvision==0.8.2 ftfy tqdm regex kornia==0.5.1 textwrap3==0.9.2
!pip install boto3==1.17.72 botocore==1.20.72
!pip install clip-by-openai

In [None]:
from google.colab import drive
drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [None]:
import torch.nn as nn
import torch.nn.functional as F
import torch
import math
import json
import logging
import os
import shutil
import tempfile
from functools import wraps
from hashlib import sha256
import sys
import copy
import boto3
import requests
from botocore.exceptions import ClientError
from tqdm import tqdm

# File names looked up inside a local model directory by BigGAN.from_pretrained.
WEIGHTS_NAME = 'pytorch_model.bin'
CONFIG_NAME = 'config.json'

# Python 3 location of urlparse first; the Python 2 module is the fallback.
try:
    from urllib.parse import urlparse
except ImportError:
    from urlparse import urlparse

# Default cache directory for downloaded files. The location comes from the
# PYTORCH_PRETRAINED_BIGGAN_CACHE environment variable and falls back to
# `.pytorch_pretrained_biggan` under the current working directory. The value
# is a `Path` when pathlib can be imported and a plain string otherwise.
try:
    from pathlib import Path
    PYTORCH_PRETRAINED_BIGGAN_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BIGGAN_CACHE',
                                                   os.path.join(os.getcwd(), '.pytorch_pretrained_biggan')))
except (AttributeError, ImportError):
    PYTORCH_PRETRAINED_BIGGAN_CACHE = os.getenv('PYTORCH_PRETRAINED_BIGGAN_CACHE',
                                              os.path.join(os.getcwd(), '.pytorch_pretrained_biggan'))

# Module-level logger used by the download/cache helpers and BigGAN.from_pretrained.
logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


def url_to_filename(url, etag=None):
    """
    Map `url` to a deterministic cache filename.

    The name is the SHA-256 hex digest of the UTF-8 encoded `url`. When a
    truthy `etag` is supplied, the SHA-256 hex digest of the etag is
    appended after a period.
    """
    digests = [sha256(url.encode('utf-8')).hexdigest()]
    if etag:
        digests.append(sha256(etag.encode('utf-8')).hexdigest())
    return '.'.join(digests)


def filename_to_url(filename, cache_dir=None):
    """
    Look up the url and etag recorded for a cached `filename`.

    The metadata is read from the JSON side-car file `<filename>.json` inside
    `cache_dir` (the module default cache is used when `cache_dir` is None).
    The returned etag may be ``None``.
    Raise ``EnvironmentError`` if the cached file or its metadata file is missing.
    """
    cache_root = PYTORCH_PRETRAINED_BIGGAN_CACHE if cache_dir is None else cache_dir
    if sys.version_info[0] == 3 and isinstance(cache_root, Path):
        cache_root = str(cache_root)

    cached_file = os.path.join(cache_root, filename)
    metadata_file = cached_file + '.json'
    # The cached file is checked first, then its metadata side-car.
    for required in (cached_file, metadata_file):
        if not os.path.exists(required):
            raise EnvironmentError("file {} not found".format(required))

    with open(metadata_file, encoding="utf-8") as handle:
        record = json.load(handle)
    return record['url'], record['etag']


def cached_path(url_or_filename, cache_dir=None):
    """
    Resolve `url_or_filename` to a path on the local file system.

    An http/https/s3 URL is fetched through `get_from_cache` and the path of
    the cached copy is returned. Anything else is treated as a local path and
    returned unchanged if it exists.
    Raise ``EnvironmentError`` for a scheme-less path that does not exist and
    ``ValueError`` for any other input that is neither a supported URL nor an
    existing path.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_PRETRAINED_BIGGAN_CACHE

    on_python3 = sys.version_info[0] == 3
    if on_python3 and isinstance(url_or_filename, Path):
        url_or_filename = str(url_or_filename)
    if on_python3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    scheme = urlparse(url_or_filename).scheme

    # Remote resource: serve it from the cache, downloading when needed.
    if scheme in ('http', 'https', 's3'):
        return get_from_cache(url_or_filename, cache_dir)

    # Existing local file: hand the path back as-is.
    if os.path.exists(url_or_filename):
        return url_or_filename

    # Scheme-less means it was meant as a local path, but nothing is there.
    if scheme == '':
        raise EnvironmentError("file {} not found".format(url_or_filename))

    # Unsupported scheme and not an existing path either.
    raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))


def split_s3_path(url):
    """
    Break an s3 url into ``(bucket_name, key)``.

    The bucket is the url's network location and the key is the url path with
    one leading '/' removed. Raise ``ValueError`` when either part is empty.
    """
    pieces = urlparse(url)
    bucket_name, s3_path = pieces.netloc, pieces.path
    if not (bucket_name and s3_path):
        raise ValueError("bad s3 path {}".format(url))
    # Only the first '/' is dropped; any further slashes belong to the key.
    key = s3_path[1:] if s3_path.startswith("/") else s3_path
    return bucket_name, key


def s3_request(func):
    """
    Decorator for S3 helpers that turns a "404" ``ClientError`` into a
    more helpful ``EnvironmentError``.

    The wrapped function must take the S3 url as its first positional
    argument; that url is used in the error message. Any other
    ``ClientError`` is re-raised unchanged.
    """

    @wraps(func)
    def wrapper(url, *args, **kwargs):
        try:
            return func(url, *args, **kwargs)
        except ClientError as exc:
            # Error codes are strings and are often non-numeric (for example
            # "AccessDenied"). Comparing as text avoids the ValueError that
            # int() would raise, which would hide the original ClientError.
            error_code = str(exc.response.get("Error", {}).get("Code", ""))
            if error_code == "404":
                raise EnvironmentError("file {} not found".format(url)) from exc
            raise

    return wrapper


@s3_request
def s3_etag(url):
    """Return the ETag reported by S3 for the object addressed by `url`."""
    s3 = boto3.resource("s3")
    bucket, key = split_s3_path(url)
    return s3.Object(bucket, key).e_tag


@s3_request
def s3_get(url, temp_file):
    """Download the S3 object addressed by `url` into the open file object `temp_file`."""
    s3 = boto3.resource("s3")
    bucket, key = split_s3_path(url)
    target_bucket = s3.Bucket(bucket)
    target_bucket.download_fileobj(key, temp_file)


def http_get(url, temp_file):
    """
    Stream the body of `url` into the open binary file object `temp_file`,
    showing a byte-count progress bar.

    Raises ``requests.HTTPError`` when the server answers with a 4xx/5xx
    status, so that an error page is never written to `temp_file`.
    """
    # The context managers release the connection and close the progress bar
    # even when the download is interrupted part-way.
    with requests.get(url, stream=True) as req:
        req.raise_for_status()
        content_length = req.headers.get('Content-Length')
        total = int(content_length) if content_length is not None else None
        with tqdm(unit="B", total=total) as progress:
            for chunk in req.iter_content(chunk_size=1024):
                if chunk:  # filter out keep-alive new chunks
                    progress.update(len(chunk))
                    temp_file.write(chunk)


def get_from_cache(url, cache_dir=None):
    """
    Return the local path of the cached copy of `url`, downloading it first
    when the cache has no entry for it.

    The cache file name is derived from the url and the ETag reported by the
    server (S3 object ETag for ``s3://`` urls, the ``ETag`` header of a HEAD
    request otherwise). A JSON side-car `<cache file>.json` records the url
    and etag. Raise ``IOError`` when the HEAD request does not return 200.
    """
    cache_root = PYTORCH_PRETRAINED_BIGGAN_CACHE if cache_dir is None else cache_dir
    if sys.version_info[0] == 3 and isinstance(cache_root, Path):
        cache_root = str(cache_root)

    if not os.path.exists(cache_root):
        os.makedirs(cache_root)

    is_s3 = url.startswith("s3://")

    # The ETag (if any) becomes part of the cache file name.
    if is_s3:
        etag = s3_etag(url)
    else:
        head = requests.head(url, allow_redirects=True)
        if head.status_code != 200:
            raise IOError("HEAD request failed for url {} with status code {}"
                          .format(url, head.status_code))
        etag = head.headers.get("ETag")

    cache_path = os.path.join(cache_root, url_to_filename(url, etag))

    # Cache hit: nothing to download.
    if os.path.exists(cache_path):
        return cache_path

    # Download into a temporary file and copy it to the cache only once the
    # transfer has finished, so an interrupted download does not leave a
    # truncated cache entry behind.
    with tempfile.NamedTemporaryFile() as temp_file:
        logger.info("%s not found in cache, downloading to %s", url, temp_file.name)

        download = s3_get if is_s3 else http_get
        download(url, temp_file)

        # The file is copied before it is closed: flush pending writes and
        # rewind, because shutil.copyfileobj() starts at the current position.
        temp_file.flush()
        temp_file.seek(0)

        logger.info("copying %s to cache at %s", temp_file.name, cache_path)
        with open(cache_path, 'wb') as cache_file:
            shutil.copyfileobj(temp_file, cache_file)

        logger.info("creating metadata file for %s", cache_path)
        with open(cache_path + '.json', 'w', encoding="utf-8") as meta_file:
            json.dump({'url': url, 'etag': etag}, meta_file)

        logger.info("removing temp file %s", temp_file.name)

    return cache_path


def read_set_from_file(filename):
    '''
    Read a UTF-8 text file holding one item per line and return the distinct
    lines as a set, with trailing whitespace removed from each line.
    '''
    with open(filename, 'r', encoding='utf-8') as source:
        return {entry.rstrip() for entry in source}


def get_file_extension(path, dot=True, lower=True):
    """
    Return the extension of `path` as given by ``os.path.splitext``.

    `dot` keeps the leading period when true; `lower` lower-cases the result
    when true. A path without an extension yields an empty string.
    """
    _, suffix = os.path.splitext(path)
    if not dot:
        suffix = suffix[1:]
    if lower:
        suffix = suffix.lower()
    return suffix


# Shortcut model names accepted by BigGAN.from_pretrained, mapped to the URL
# of the corresponding weights file.
PRETRAINED_MODEL_ARCHIVE_MAP = {
    'biggan-deep-128': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-128-pytorch_model.bin",
    'biggan-deep-256': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-256-pytorch_model.bin",
    'biggan-deep-512': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-512-pytorch_model.bin",
}

# Same shortcut names, mapped to the URL of the matching JSON configuration.
# BigGAN.from_pretrained indexes both maps with the same key, so the two
# dictionaries must keep identical key sets.
PRETRAINED_CONFIG_ARCHIVE_MAP = {
    'biggan-deep-128': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-128-config.json",
    'biggan-deep-256': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-256-config.json",
    'biggan-deep-512': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-512-config.json",
}

class BigGANConfig(object):
    """ Configuration class to store the configuration of a `BigGAN`.
        Defaults are for the 128x128 model.
        layers tuple are (up-sample in the layer ?, input channels, output channels)
    """
    def __init__(self,
                 output_dim=128,
                 z_dim=128,
                 class_embed_dim=128,
                 channel_width=128,
                 num_classes=1000,
                 layers=None,
                 attention_layer_position=8,
                 eps=1e-4,
                 n_stats=51):
        """Constructs BigGANConfig.

        `layers=None` selects the default 128x128 layer layout. A fresh list
        is built on every call so that instances never share (and cannot
        accidentally mutate) one default list object.
        """
        if layers is None:
            layers = [(False, 16, 16),
                      (True, 16, 16),
                      (False, 16, 16),
                      (True, 16, 8),
                      (False, 8, 8),
                      (True, 8, 4),
                      (False, 4, 4),
                      (True, 4, 2),
                      (False, 2, 2),
                      (True, 2, 1)]
        self.output_dim = output_dim
        self.z_dim = z_dim
        self.class_embed_dim = class_embed_dim
        self.channel_width = channel_width
        self.num_classes = num_classes
        self.layers = layers
        self.attention_layer_position = attention_layer_position
        self.eps = eps
        self.n_stats = n_stats

    @classmethod
    def from_dict(cls, json_object):
        """Constructs a config from a Python dictionary of parameters.

        Starts from the defaults and overwrites/adds one attribute per key.
        The instance is created through `cls`, so subclasses get an instance
        of their own type.
        """
        config = cls()
        for key, value in json_object.items():
            config.__dict__[key] = value
        return config

    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a config from a UTF-8 json file of parameters."""
        with open(json_file, "r", encoding='utf-8') as reader:
            text = reader.read()
        return cls.from_dict(json.loads(text))

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a Python dictionary (deep copy of the attributes)."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string (sorted keys, trailing newline)."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"


def snconv2d(eps=1e-12, **kwargs):
    """Build an `nn.Conv2d` from `kwargs` and wrap it in spectral normalization with the given `eps`."""
    conv = nn.Conv2d(**kwargs)
    return nn.utils.spectral_norm(conv, eps=eps)

def snlinear(eps=1e-12, **kwargs):
    """Build an `nn.Linear` from `kwargs` and wrap it in spectral normalization with the given `eps`."""
    linear = nn.Linear(**kwargs)
    return nn.utils.spectral_norm(linear, eps=eps)

def sn_embedding(eps=1e-12, **kwargs):
    """Build an `nn.Embedding` from `kwargs` and wrap it in spectral normalization with the given `eps`."""
    embedding = nn.Embedding(**kwargs)
    return nn.utils.spectral_norm(embedding, eps=eps)

class SelfAttn(nn.Module):
    """Self-attention layer built from spectrally-normalized 1x1 convolutions.

    The output is `x + gamma * attention(x)`, where `gamma` is a learnable
    scalar initialised to zero.
    """
    def __init__(self, in_channels, eps=1e-12):
        super().__init__()
        self.in_channels = in_channels

        def conv1x1(n_in, n_out):
            # Bias-free 1x1 convolution with spectral normalization.
            return snconv2d(in_channels=n_in, out_channels=n_out,
                            kernel_size=1, bias=False, eps=eps)

        self.snconv1x1_theta = conv1x1(in_channels, in_channels // 8)
        self.snconv1x1_phi = conv1x1(in_channels, in_channels // 8)
        self.snconv1x1_g = conv1x1(in_channels, in_channels // 2)
        self.snconv1x1_o_conv = conv1x1(in_channels // 2, in_channels)
        self.maxpool = nn.MaxPool2d(2, stride=2, padding=0)
        self.softmax = nn.Softmax(dim=-1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        _, channels, height, width = x.size()
        n_positions = height * width
        n_pooled = n_positions // 4  # the 2x2 max-pool quarters the spatial positions

        # Theta path: one vector per spatial position.
        theta = self.snconv1x1_theta(x).view(-1, channels // 8, n_positions)
        # Phi path: max-pooled before flattening.
        phi = self.maxpool(self.snconv1x1_phi(x)).view(-1, channels // 8, n_pooled)
        # Attention map, softmax over the pooled positions.
        attn = self.softmax(torch.bmm(theta.permute(0, 2, 1), phi))
        # g path: max-pooled before flattening.
        g = self.maxpool(self.snconv1x1_g(x)).view(-1, channels // 2, n_pooled)
        # Apply the attention map to g, restore the spatial layout, project back.
        attended = torch.bmm(g, attn.permute(0, 2, 1)).view(-1, channels // 2, height, width)
        attended = self.snconv1x1_o_conv(attended)
        # Residual connection scaled by the learnable gamma.
        return x + self.gamma * attended


class BigGANBatchNorm(nn.Module):
    """ This is a batch norm module that can handle conditional input and can be provided with pre-computed
        activation means and variances for various truncation parameters.
        We cannot just rely on torch.batch_norm since it cannot handle
        batched weights (pytorch 1.0.1). We compute batch_norm ourselves without updating running means and variances.
        If you want to train this model you should add running means and variance computation logic.

        Args:
            num_features: number of channels being normalized.
            condition_vector_dim: size of the conditioning vector; required when `conditional` is True.
            n_stats: number of pre-computed statistics rows, evenly spaced over truncation values in [0, 1].
            eps: value added to the variance before the square root (also passed to the spectral norm wrappers).
            conditional: when True, scale and offset are predicted from the condition vector;
                otherwise learnable `weight` and `bias` parameters are used.
    """
    def __init__(self, num_features, condition_vector_dim=None, n_stats=51, eps=1e-4, conditional=True):
        super(BigGANBatchNorm, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.conditional = conditional

        # We use pre-computed statistics for n_stats values of truncation between 0 and 1
        self.register_buffer('running_means', torch.zeros(n_stats, num_features))
        self.register_buffer('running_vars', torch.ones(n_stats, num_features))
        self.step_size = 1.0 / (n_stats - 1)

        if conditional:
            assert condition_vector_dim is not None
            self.scale = snlinear(in_features=condition_vector_dim, out_features=num_features, bias=False, eps=eps)
            self.offset = snlinear(in_features=condition_vector_dim, out_features=num_features, bias=False, eps=eps)
        else:
            # Identity affine transform by default. The previous
            # `torch.Tensor(num_features)` left these as uninitialized memory
            # whenever a checkpoint did not provide them.
            self.weight = torch.nn.Parameter(torch.ones(num_features))
            self.bias = torch.nn.Parameter(torch.zeros(num_features))

    def forward(self, x, truncation, condition_vector=None):
        """Normalize `x` with the stored statistics that correspond to `truncation`.

        `truncation / step_size` is split into an integer row index and a
        fractional part; a non-zero fractional part linearly interpolates
        between that row and the next one.
        """
        # Retrieve pre-computed statistics associated to this truncation
        coef, start_idx = math.modf(truncation / self.step_size)
        start_idx = int(start_idx)
        if coef != 0.0:  # Interpolate
            # `coef` is the fractional distance from row `start_idx` towards
            # row `start_idx + 1`, so it weights the upper row. The weights
            # were previously swapped.
            running_mean = self.running_means[start_idx] * (1 - coef) + self.running_means[start_idx + 1] * coef
            running_var = self.running_vars[start_idx] * (1 - coef) + self.running_vars[start_idx + 1] * coef
        else:
            running_mean = self.running_means[start_idx]
            running_var = self.running_vars[start_idx]

        if self.conditional:
            # Reshape the per-channel statistics to broadcast over (batch, channel, H, W).
            running_mean = running_mean.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
            running_var = running_var.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)

            # Per-sample scale (centred on 1) and offset predicted from the condition vector.
            weight = 1 + self.scale(condition_vector).unsqueeze(-1).unsqueeze(-1)
            bias = self.offset(condition_vector).unsqueeze(-1).unsqueeze(-1)

            out = (x - running_mean) / torch.sqrt(running_var + self.eps) * weight + bias
        else:
            out = F.batch_norm(x, running_mean, running_var, self.weight, self.bias,
                               training=False, momentum=0.0, eps=self.eps)

        return out


class GenBlock(nn.Module):
    """Residual generator block: four (conditional batch norm, ReLU, conv) stages.

    The stages are 1x1, 3x3, 3x3 and 1x1 convolutions working at
    `in_size // reduction_factor` channels in the middle. With `up_sample`
    the main path is up-sampled by 2 before the first 3x3 convolution and the
    skip path is up-sampled to match. When `in_size != out_size` the skip
    path keeps only the first half of its channels.
    """
    def __init__(self, in_size, out_size, condition_vector_dim, reduction_factor=4, up_sample=False,
                 n_stats=51, eps=1e-12):
        super().__init__()
        self.up_sample = up_sample
        self.drop_channels = (in_size != out_size)
        middle_size = in_size // reduction_factor

        def cond_bn(num_features):
            # Conditional batch norm sharing this block's settings.
            return BigGANBatchNorm(num_features, condition_vector_dim,
                                   n_stats=n_stats, eps=eps, conditional=True)

        self.bn_0 = cond_bn(in_size)
        self.conv_0 = snconv2d(in_channels=in_size, out_channels=middle_size, kernel_size=1, eps=eps)

        self.bn_1 = cond_bn(middle_size)
        self.conv_1 = snconv2d(in_channels=middle_size, out_channels=middle_size, kernel_size=3, padding=1, eps=eps)

        self.bn_2 = cond_bn(middle_size)
        self.conv_2 = snconv2d(in_channels=middle_size, out_channels=middle_size, kernel_size=3, padding=1, eps=eps)

        self.bn_3 = cond_bn(middle_size)
        self.conv_3 = snconv2d(in_channels=middle_size, out_channels=out_size, kernel_size=1, eps=eps)

        self.relu = nn.ReLU()

    def forward(self, x, cond_vector, truncation):
        shortcut = x

        stages = (
            (self.bn_0, self.conv_0),
            (self.bn_1, self.conv_1),
            (self.bn_2, self.conv_2),
            (self.bn_3, self.conv_3),
        )
        for index, (norm, conv) in enumerate(stages):
            x = self.relu(norm(x, truncation, cond_vector))
            # Up-sampling happens once, right before the first 3x3 convolution.
            if index == 1 and self.up_sample:
                x = F.interpolate(x, scale_factor=2, mode='nearest')
            x = conv(x)

        # Bring the skip connection to the shape of the main path.
        if self.drop_channels:
            kept = shortcut.shape[1] // 2
            shortcut = shortcut[:, :kept, ...]
        if self.up_sample:
            shortcut = F.interpolate(shortcut, scale_factor=2, mode='nearest')

        return x + shortcut

class Generator(nn.Module):
    """Stack of `GenBlock`s (plus one `SelfAttn` layer) that maps condition vectors to an RGB image.

    `forward` expects `cond_vector` to be indexable by row: row 0 drives the
    initial linear layer and row `i + 1` drives the module at position `i`
    of `self.layers` when that module is a `GenBlock`. Positions are counted
    over all modules, so the row matching the `SelfAttn` position is unused.
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        width = config.channel_width
        condition_vector_dim = config.z_dim * 2

        self.gen_z = snlinear(in_features=condition_vector_dim,
                              out_features=4 * 4 * 16 * width, eps=config.eps)

        modules = []
        for position, spec in enumerate(config.layers):
            up_sample, in_mult, out_mult = spec[0], spec[1], spec[2]
            # The attention layer is inserted in front of the block at the configured position.
            if position == config.attention_layer_position:
                modules.append(SelfAttn(width * in_mult, eps=config.eps))
            modules.append(GenBlock(width * in_mult,
                                    width * out_mult,
                                    condition_vector_dim,
                                    up_sample=up_sample,
                                    n_stats=config.n_stats,
                                    eps=config.eps))
        self.layers = nn.ModuleList(modules)

        self.bn = BigGANBatchNorm(width, n_stats=config.n_stats, eps=config.eps, conditional=False)
        self.relu = nn.ReLU()
        self.conv_to_rgb = snconv2d(in_channels=width, out_channels=width, kernel_size=3, padding=1, eps=config.eps)
        self.tanh = nn.Tanh()

    def forward(self, cond_vector, truncation):
        z = self.gen_z(cond_vector[0].unsqueeze(0))

        # We use this conversion step to be able to use TF weights:
        # TF convention on shape is [batch, height, width, channels]
        # PT convention on shape is [batch, channels, height, width]
        z = z.view(-1, 4, 4, 16 * self.config.channel_width).permute(0, 3, 1, 2).contiguous()

        for position, module in enumerate(self.layers):
            if not isinstance(module, GenBlock):
                z = module(z)
                continue
            z = module(z, cond_vector[position + 1].unsqueeze(0), truncation)

        z = self.conv_to_rgb(self.relu(self.bn(z, truncation)))
        # Only the first three output channels are used as RGB.
        return self.tanh(z[:, :3, ...])

class BigGAN(nn.Module):
    """BigGAN Generator: a linear class embedding followed by the conditional `Generator`."""

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
        """Build a model from a shortcut name or a local directory and load its weights.

        A name found in `PRETRAINED_MODEL_ARCHIVE_MAP` selects the published
        weight/config URLs; anything else is treated as a directory holding
        `WEIGHTS_NAME` and `CONFIG_NAME`. Files are resolved through
        `cached_path`; an ``EnvironmentError`` from it is logged and re-raised.
        """
        name = pretrained_model_name_or_path
        if name in PRETRAINED_MODEL_ARCHIVE_MAP:
            model_file = PRETRAINED_MODEL_ARCHIVE_MAP[name]
            config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[name]
        else:
            model_file = os.path.join(name, WEIGHTS_NAME)
            config_file = os.path.join(name, CONFIG_NAME)

        try:
            resolved_model_file = cached_path(model_file, cache_dir=cache_dir)
            resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
        except EnvironmentError:
            message = ("Wrong model name, should be a valid path to a folder containing "
                       "a {} file and a {} file or a model name in {}").format(
                           WEIGHTS_NAME, CONFIG_NAME, PRETRAINED_MODEL_ARCHIVE_MAP.keys())
            logger.error(message)
            raise

        logger.info("loading model {} from cache at {}".format(name, resolved_model_file))

        # Load config
        config = BigGANConfig.from_json_file(resolved_config_file)
        logger.info("Model config {}".format(config))

        # Instantiate the model, then load the weights. Tensors are mapped to
        # the CPU when CUDA is not available.
        # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
        model = cls(config, *inputs, **kwargs)
        device_map = None if torch.cuda.is_available() else 'cpu'
        state_dict = torch.load(resolved_model_file, map_location=device_map)
        # strict=False: missing or unexpected keys in the checkpoint are tolerated.
        model.load_state_dict(state_dict, strict=False)
        return model

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embeddings = nn.Linear(config.num_classes, config.z_dim, bias=False)
        self.generator = Generator(config)

    def forward(self, z, class_label, truncation):
        """Generate from latent `z` and class vector `class_label`; requires 0 < truncation <= 1."""
        assert 0 < truncation <= 1

        # The condition vector is the latent concatenated with the class embedding.
        class_embedding = self.embeddings(class_label)
        cond_vector = torch.cat((z, class_embedding), dim=1)

        return self.generator(cond_vector, truncation)

def one_hot_from_int(int_or_list, batch_size=1):
    """ Create a one-hot vector from a class index or a list of class indices.
        Params:
            int_or_list: int (Python or NumPy integer), or list of int, of the imagenet classes (between 0 and 999)
            batch_size: batch size.
                If int_or_list is an int create a batch of identical classes.
                If int_or_list is a list, we should have `len(int_or_list) == batch_size`
        Output:
            float32 array of shape (batch_size, 1000)
        Raises:
            AssertionError if the number of indices does not match `batch_size`.
    """
    # NumPy integer scalars (e.g. the result of np.argmax) are not instances
    # of `int` on Python 3 and used to crash on len(); treat them as a single
    # class index as well.
    if isinstance(int_or_list, (int, np.integer)):
        int_or_list = [int(int_or_list)]

    # A single index is repeated to fill the batch.
    if len(int_or_list) == 1 and batch_size > 1:
        int_or_list = [int_or_list[0]] * batch_size

    assert batch_size == len(int_or_list)

    array = np.zeros((batch_size, 1000), dtype=np.float32)
    for row, class_index in enumerate(int_or_list):
        array[row, class_index] = 1.0
    return array



In [None]:
#------------------------------------------------------------UTILS -------------------------------------------------------------
#from dall_e import map_pixels, unmap_pixels
import numpy        as np
import torchvision
import tempfile
import imageio
import random
import kornia
import shutil
import torch
import time
import os
import re


def create_outputfolder():
    """Create an empty `output` directory in the current working directory.

    An existing `output` directory is deleted together with its contents first.
    """
    target = os.path.join(os.getcwd(), 'output')
    already_there = os.path.exists(target)
    if already_there:
        shutil.rmtree(target)
    os.mkdir(target)

def create_strp(d, timeformat):
    """Parse the string `d` with `time.strptime(d, timeformat)` and return `time.mktime` of the result."""
    parsed = time.strptime(d, timeformat)
    return time.mktime(parsed)

# def download_stylegan_pt():
#     cwd = os.getcwd()
#     print(cwd)
#     if 'stylegan.pt' not in os.listdir(cwd):
#         url = "https://github.com/lernapparat/lernapparat/releases/download/v2019-02-01/karras2019stylegan-ffhq-1024x1024.for_g_all.pt"
#         wget.download(url, "stylegan.pt")

def init_textfile(textfile):
    """Read a timed description file and return a list of `(timestamp, text)` tuples.

    Each non-blank line must look like `MM:SS some text` or, alternatively,
    `MM:SS.ff some text`. The format of the first non-blank line decides
    which of the two patterns is applied to the whole file.

    Timestamps are the floats produced by `time.mktime(time.strptime(...))`,
    so only differences between them are meaningful; any fractional-second
    part is parsed but not retained by that conversion.

    A `"start song"` entry at time zero is inserted in front unless the first
    line already starts at zero, and an `"end song"` entry is appended nine
    seconds after the last line.

    Raises:
        ValueError: if the file has no non-blank lines or a line does not
            match the detected format.
    """
    plain_pattern = re.compile(r'(\d\d:\d\d) (.*)')
    fractional_pattern = re.compile(r'(\d\d:\d\d\.\d\d) (.*)')

    with open(textfile, 'r') as file:
        stripped = [raw.strip() for raw in file]
    # Blank lines (typically a trailing newline) carry no description.
    lines = [line for line in stripped if line]
    if not lines:
        raise ValueError("no timed descriptions found in {}".format(textfile))

    # Previously the fallback to the fractional format could never run: the
    # failed plain match raised IndexError before the fallback was reached.
    if plain_pattern.search(lines[0]):
        pattern, timeformat, starttime = plain_pattern, "%M:%S", "00:00"
    else:
        pattern, timeformat, starttime = fractional_pattern, "%M:%S.%f", "00:00.00"

    def to_timestamp(stamp):
        return time.mktime(time.strptime(stamp, timeformat))

    descs1 = []
    for line in lines:
        match = pattern.search(line)
        if match is None:
            raise ValueError("cannot parse timed description line {!r}".format(line))
        stamp, description = match.groups()
        descs1.append((to_timestamp(stamp), description))

    firstline = (to_timestamp(starttime), "start song")
    if descs1[0][0] != firstline[0]:
        descs1.insert(0, firstline)

    lastline = (descs1[-1][0] + 9, "end song")
    descs1.append(lastline)

    return descs1

def create_image(img, i, text, gen, pre_scaled=True):
    """Write `img` to a PNG file and return the temporary-file object that names it.

    Params:
        img: image data. For gen == 'stylegan' it is treated as a tensor
            (`clamp`, `permute`, `detach` are called on it); otherwise it is
            converted with `np.array` and transposed with axes (1, 2, 0).
        i, text: not used by this function.
        gen: generator name; only 'stylegan' selects the tensor path.
        pre_scaled: when False the image is passed through `scale(...)`.
    Output:
        the `tempfile.NamedTemporaryFile` object. It is already closed when
        returned; the image itself lives at `<returned object>.name + ".png"`.
    """
    if gen == 'stylegan':
        # Clamp to [-1, 1], shift to [0, 1], then take the first batch item,
        # move channels last and multiply by 256.
        img = (img.clamp(-1, 1) + 1) / 2.0
        img = img[0].permute(1, 2, 0).detach().cpu().numpy() * 256
    else:
        img = np.array(img)[:,:,:]
        # Reorder axes (0, 1, 2) -> (1, 2, 0); presumably channels-first to
        # channels-last - TODO confirm against the caller.
        img = np.transpose(img, (1, 2, 0))
    if not pre_scaled:
        # NOTE(review): `scale` is not defined or imported in the visible part
        # of this file; this branch raises NameError unless it is defined elsewhere.
        img = scale(img, 48*4, 32*4)
    img = np.array(img)
    with tempfile.NamedTemporaryFile() as image_temp:
        # The temporary file only supplies a unique name. The PNG is written
        # to a separate path (name + ".png"), which is not removed when the
        # `with` block closes the temporary file.
        imageio.imwrite(image_temp.name+".png", img)
        image_temp.seek(0)
        return image_temp

# Per-channel (mean, std) normalization applied in ascend_txt to image crops
# right before `perceptor.encode_image`.
# NOTE(review): the constants look like CLIP's image preprocessing statistics - confirm.
nom = torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))

class Pars(torch.nn.Module):
    """Holder of the trainable latent parameters for the selected generator.

    gen == 'biggan':   `normu` (32 x 128, standard normal) and `cls`
                       (32 x 1000, normal with mean -3.9 and std 0.3), plus
                       the constant tensors `thrsh_lat` and `thrsh_cls`.
    gen == 'dall-e':   `normu` of zeros with shape (1, 8192, 64, 64).
    gen == 'stylegan': `normu` of zeros with shape (1, 1, 512).
    """
    def __init__(self, gen='biggan'):
        super().__init__()
        self.gen = gen
        if gen == 'biggan':
            latent_init = torch.zeros(32, 128).normal_(std=1)
            self.normu = torch.nn.Parameter(latent_init)
            class_init = torch.zeros(32, 1000).normal_(-3.9, .3)
            self.cls = torch.nn.Parameter(class_init)
            self.thrsh_lat = torch.tensor(1)
            self.thrsh_cls = torch.tensor(1.9)
        elif gen == 'dall-e':
            self.normu = torch.nn.Parameter(torch.zeros(1, 8192, 64, 64))
        elif gen == 'stylegan':
            # squeeze(-1) leaves the shape unchanged because the last dimension is 512.
            zeros = torch.zeros((1, 1, 512)).squeeze(-1)
            self.normu = torch.nn.Parameter(zeros, requires_grad=True)

    def forward(self):
        """Return the current parameters in the form the generator expects.

        biggan: `(normu, sigmoid(cls))`; dall-e: gumbel-softmax (tau=2) of
        `normu` over the flattened spatial positions, reshaped back to
        (1, 8192, 64, 64); any other generator: None.
        """
        if self.gen == 'biggan':
            return self.normu, torch.sigmoid(self.cls)

        if self.gen == 'dall-e':
            flat = self.normu.view(1, 8192, -1)
            sampled = torch.nn.functional.gumbel_softmax(flat, dim=-1, tau=2)
            return sampled.view(1, 8192, 64, 64)

        return None


def pad_augs(image):
    """Pad `image` with ones by a random total amount on its last two dimensions.

    A total padding of 1 to 50 is drawn, then split at a random 10%-90%
    point between the two sides of the last dimension and, independently,
    between the two sides of the second-to-last dimension.
    """
    total = random.randint(1, 50)
    share_x = random.randint(10, 90) / 100
    share_y = random.randint(10, 90) / 100
    before_x = int(total * share_x)
    before_y = int(total * share_y)
    padding = (before_x, total - before_x, before_y, total - before_y)
    return torch.nn.functional.pad(image, padding, "constant", 1)

def kornia_augs(image, sideX=128):
    """Run `image` through a chain of random kornia augmentations.

    The chain is: affine (p=0.55), horizontal flip, Gaussian blur (p=0.5),
    sharpness and colour jitter (p=0.6). The blur kernel size is a random odd
    number between 1 and `2 * int(sideX / 5) + 1`, and the same number is
    used for the blur sigma.
    """
    kernel = 2 * random.randint(0, int(sideX / 5)) + 1
    augmentations = [
        kornia.augmentation.RandomAffine(20, p=0.55, keepdim=True),
        kornia.augmentation.RandomHorizontalFlip(),
        kornia.augmentation.GaussianBlur((kernel, kernel), (kernel, kernel), p=0.5, border_type="constant"),
        kornia.augmentation.RandomSharpness(.5),
        kornia.augmentation.ColorJitter(0.1, 0.1, 0.1, 0.1, p=0.6),
    ]
    pipeline = torch.nn.Sequential(*augmentations)
    return pipeline(image)

def ascend_txt(model, lats, sideX, sideY, perceptor, percep, gen, tokenizedtxt):
    """Generate an image from the current latents and compute the loss terms.

    Params:
        model: generator called with the latents.
        lats: latent holder; called as `lats()` for biggan/dall-e, and its
            `normu` / `thrsh_lat` attributes are read.
        sideX, sideY: image side lengths used when drawing random crops.
        perceptor: model providing `encode_image` (and, for stylegan, a
            joint image/text call).
        percep: target embedding compared against the crop embeddings.
        gen: 'biggan', 'dall-e' or 'stylegan'.
        tokenizedtxt: tokenized text, only used in the stylegan branch.
    Output:
        biggan:   [lat_l, cls_l, -100 * mean cosine similarity, zs]
        dall-e:   [0, 10 * negative cosine similarity (averaged), zs]
        stylegan: (1/img_logits * 100, img, zs)
    """
    if gen == 'biggan':
        cutn = 128
        zs = [*lats()]
        # Truncation is fixed to 1 for the BigGAN forward pass.
        out = model(zs[0], zs[1], 1)

    elif gen == 'dall-e':
        cutn = 32
        zs = lats()
        # NOTE(review): the `unmap_pixels` import is commented out at the top
        # of this cell; this line raises NameError unless it is defined elsewhere.
        out = unmap_pixels(torch.sigmoid(model(zs)[:, :3].float()))

    elif gen == 'stylegan':
        zs = lats.normu.repeat(1,18,1)
        img = model(zs)
        img = torch.nn.functional.upsample_bilinear(img, (224, 224))

        # NOTE(review): `.cuda()` requires a CUDA device, while the other
        # `.cuda()` calls in this cell are commented out.
        img_logits, _text_logits = perceptor(img, tokenizedtxt.cuda())

        # The stylegan branch returns early and skips the crop-based scoring below.
        return 1/img_logits * 100, img, zs

    # Build `cutn` random square crops of the generated image.
    p_s = []
    for ch in range(cutn):
        # size = int(sideX*torch.zeros(1,).normal_(mean=.8, std=.3).clip(.5, .95))
        # Crop side: a normal sample clipped to [0.362, 0.7099], times sideX.
        size = int(sideX*torch.zeros(1,).normal_(mean=.39, std=.865).clip(.362, .7099))
        offsetx = torch.randint(0, sideX - size, ())
        offsety = torch.randint(0, sideY - size, ())
        apper = out[:, :, offsetx:offsetx + size, offsety:offsety + size]
        apper = pad_augs(apper)
        # apper = kornia_augs(apper, sideX=sideX)
        # Every crop is resized to 224x224 before it is encoded.
        apper = torch.nn.functional.interpolate(apper, (224, 224), mode='nearest')
        p_s.append(apper)
    # Stack all crops along the batch dimension.
    into = torch.cat(p_s, 0)
    if gen == 'biggan':
        # into = nom((into + 1) / 2)
        # Add a small amount of Gaussian noise, then rescale and normalize.
        up_noise = 0.01649
        into = into + (up_noise)*torch.randn_like(into, requires_grad=True)
        into = nom((into + 1) / 1.8)
    elif gen == 'dall-e':
        into = nom((into + 1) / 2)
    iii = perceptor.encode_image(into)
    llls = zs #lats()

    if gen == 'dall-e':
        return [0, 10*-torch.cosine_similarity(percep, iii).view(-1, 1).T.mean(1), zs]

    # Latent regularizer: |1 - std| per row, |mean|, and 4 * max(mean of
    # squares, thrsh_lat), all computed on llls[0].
    lat_l = torch.abs(1 - torch.std(llls[0], dim=1)).mean() + \
            torch.abs(torch.mean(llls[0])).mean() + \
            4*torch.max(torch.square(llls[0]).mean(), lats.thrsh_lat)
    # Add the absolute skewness and excess kurtosis of every latent row,
    # each divided by the number of rows.
    for array in llls[0]:
        mean = torch.mean(array)
        diffs = array - mean
        var = torch.mean(torch.pow(diffs, 2.0))
        std = torch.pow(var, 0.5)
        zscores = diffs / std
        skews = torch.mean(torch.pow(zscores, 3.0))
        kurtoses = torch.mean(torch.pow(zscores, 4.0)) - 3.0
        lat_l = lat_l + torch.abs(kurtoses) / llls[0].shape[0] + torch.abs(skews) / llls[0].shape[0]
    # Class regularizer: mean of (50 * value)^2 over the 999 smallest entries
    # of every row of llls[1].
    cls_l = ((50*torch.topk(llls[1],largest=False,dim=1,k=999)[0])**2).mean()
    return [lat_l, cls_l, -100*torch.cosine_similarity(percep, iii, dim=-1).mean(), zs]

def train(i, model, lats, sideX, sideY, perceptor, percep, optimizer, text, tokenizedtxt, epochs=200, gen='biggan', img=None):
    """Run one optimization step and return the latents on the last step.

    The loss terms come from `ascend_txt`; how they are combined depends on
    `gen`. Returns the latents `zs` when `i + 1 == epochs`, otherwise False.
    """
    terms = ascend_txt(model, lats, sideX, sideY, perceptor, percep, gen, tokenizedtxt)

    if gen == 'biggan':
        # latent regularizer + class regularizer + similarity term
        loss = terms[0] + terms[1] + terms[2]
        zs = terms[3]
    elif gen == 'dall-e':
        loss = (terms[0] + terms[1]).mean()
        zs = terms[2]
    elif gen == 'stylegan':
        loss = terms[0]
        img = terms[1].cpu()
        zs = terms[2]

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Only the final step hands the latents back.
    is_last_step = (i + 1 == epochs)
    return zs if is_last_step else False

In [None]:
#------------------------------------------------------------CREATE VIDEO -------------------------------------------------------------
# from moviepy.config import change_settings
# change_settings({"IMAGEMAGICK_BINARY": "/usr/local/lib/"})
from textwrap3  import fill
import moviepy.editor as me
import tempfile
import textwrap
import glob
import os

def createvid(description, image_temp_list, fps=24, duration=0.1):
    """Render one lyric segment: frames on a black canvas with a caption bar.

    Args:
        description: Caption text. ``"start song"`` is shown as a single
            space, and captions longer than 35 characters are wrapped.
        image_temp_list: File-like objects; the frame for each one is read
            from ``<obj.name>.png`` and the object is closed afterwards.
        fps: Frame rate passed to ``write_videofile``.
        duration: Duration given to every frame clip.

    Returns:
        The temporary-file object whose name was used as the video stem. It
        has already left its ``with`` block when returned; the video itself
        is written to ``<returned.name>.mp4``.
    """
    canvas = me.ColorClip((720,720), (0, 0, 0))

    # Build every frame clip first, then release the temp-file handles.
    frame_clips = []
    for handle in image_temp_list:
        frame_clips.append(me.ImageClip(handle.name+".png", duration=duration))
    for handle in image_temp_list:
        handle.close()

    slideshow = me.concatenate_videoclips(frame_clips, method="compose")
    slideshow = slideshow.set_position(('center', 'center'))

    caption = description
    if caption == "start song":
        caption = " "
    if len(caption) > 35:
        caption = fill(caption, 35)

    # Caption on a semi-opaque black bar spanning the canvas width.
    text_clip = me.TextClip(caption, color='white', fontsize=30, font='Amiri-regular')
    text_clip = text_clip.set_position('center')
    banner = text_clip.on_color(
        size=(canvas.w, text_clip.h + 10),
        color=(0,0,0),
        pos=('center', 'center'),
        col_opacity=0.8,
    )
    banner = banner.set_position((0, canvas.h-20-text_clip.h))

    layers = [canvas, slideshow, banner]
    composite = me.CompositeVideoClip(layers)
    final = composite.set_duration(slideshow.duration)

    with tempfile.NamedTemporaryFile() as video_tempfile:
        # The temp file only lends its unique name; the video goes next to it.
        final.write_videofile(video_tempfile.name+".mp4", fps=fps)
        video_tempfile.seek(0)

        for used_clip in frame_clips + layers:
            used_clip.close()
        return video_tempfile

def concatvids(descriptions, video_temp_list, audiofilepath, fps=24, lyrics=True):
    """Join the per-lyric videos into ``output/finaloutput.mp4``.

    Args:
        descriptions: Sequence of description entries. The last entry gets no
            video of its own, so at most ``len(descriptions) - 1`` videos are
            used.
        video_temp_list: Objects whose ``name`` attribute is the stem of a
            rendered ``<name>.mp4`` file, in playback order.
        audiofilepath: Path of an audio track to attach. When truthy, the
            result takes the audio's duration. Falsy means no audio.
        fps: Frame rate passed to ``write_videofile``.
        lyrics: Not used here; kept so existing callers keep working.

    Returns:
        None. The video is written to ``output/finaloutput.mp4``.
    """
    clips = []
    audio = None
    try:
        # Pair videos with every description except the last one. The old
        # check compared a whole description entry with the text of the last
        # one, which never matches for (timestamp, text) entries.
        for _, vid in zip(descriptions[:-1], video_temp_list):
            clips.append(me.VideoFileClip(f'{vid.name}.mp4'))

        concat_clip = me.concatenate_videoclips(clips, method="compose").set_position(('center', 'center'))
        if audiofilepath:
            audio = me.AudioFileClip(audiofilepath)
            concat_clip.audio = audio
            concat_clip.duration = audio.duration
        concat_clip.write_videofile(os.path.join('output', "finaloutput.mp4"), fps=fps)
    finally:
        # Release the file readers even when loading or writing fails.
        for clip in clips:
            clip.close()
        if audio is not None:
            audio.close()


In [None]:
#----------------------------------------------------------- MAIN -------------------------------------------------------------

#from utils import train, Pars, create_image, create_outputfolder, init_textfile
#from dall_e import map_pixels, unmap_pixels, load_model
#from stylegan import g_synthesis
#from biggan import BigGAN
#from pytorch_pretrained_biggan import BigGAN

from tqdm import tqdm
#import create_video
import tempfile
import argparse
import torch
import glob
import os
import math
#import imageio
import clip 
#print(clip.__version__)

def main():
    """Optimise one set of latents per lyric line and return them.

    Configuration is read from the module-level globals ``generator``,
    ``textfile``, ``epochs``, ``sideX`` and ``sideY``.

    Returns:
        tuple: ``(templist, descs, model)`` where ``templist`` holds one open
        temporary file per entry of ``descs`` (the latents saved with
        ``torch.save`` and rewound to offset 0), ``descs`` is the value
        returned by ``init_textfile(textfile)``, and ``model`` is the loaded
        generator.

    Raises:
        ValueError: If ``generator`` is not a supported name or ``epochs`` is
            smaller than 1.
    """
    # Validate up front: a bad value used to surface much later as an
    # unbound `model`, `par` or `zs`.
    if generator not in ('biggan', 'dall-e', 'stylegan'):
        raise ValueError(
            f"unknown generator {generator!r}; expected 'biggan', 'dall-e' or 'stylegan'")
    if epochs < 1:
        raise ValueError(f"epochs must be at least 1, got {epochs!r}")

    # Automatically creates 'output' folder
    create_outputfolder()

    print("loading clip model...")
    perceptor, _ = clip.load('ViT-B/32')
    perceptor = perceptor.eval()
    print("model loaded")

    # Load the generator
    if generator == 'biggan':
        biggan_cache = 'gdrive/MyDrive/biggan-clip/biggan_deep_512.pt'
        if os.path.isfile(biggan_cache):
            print("biggan found")
            model = torch.load(biggan_cache)
            print("biggan loaded")
        else:
            # First run: fetch the pretrained weights and cache them.
            print("loading biggan...")
            model = BigGAN.from_pretrained('biggan-deep-512')
            torch.save(model, biggan_cache)
        model = model.eval()
    elif generator == 'dall-e':
        model = load_model("decoder.pkl")
    else:  # 'stylegan'
        model = g_synthesis.eval()

    # Each entry is indexed below as d[1] for the lyric text.
    descs = init_textfile(textfile)

    # One temporary file with saved latents per description
    templist = []

    for d in tqdm(descs):
        line = d[1]

        # Fresh latents and optimizer settings for every line
        lats = Pars(gen=generator)
        if generator == 'biggan':
            par = lats.parameters()
            lr = 0.1
        elif generator == 'stylegan':
            par = [lats.normu]
            lr = .01
        else:  # 'dall-e'
            par = [lats.normu]
            lr = .1
        optimizer = torch.optim.Adam(par, lr)

        # Tokenize the current line with clip and encode the text
        txt = clip.tokenize(line)
        percep = perceptor.encode_text(txt).detach().clone()

        print('training')
        for i in range(epochs):
            print(i)
            # train() returns False until the last step, then the latents.
            zs = train(i, model, lats, sideX, sideY, perceptor, percep, optimizer, line, txt, epochs=epochs, gen=generator)

        # Save this line's final latents to a temporary file and rewind it
        # so it can be read back later.
        latent_temp = tempfile.NamedTemporaryFile()
        torch.save(zs, latent_temp)
        latent_temp.seek(0)
        templist.append(latent_temp)
    return templist, descs, model

def sigmoid(x):
    """Ease-in/ease-out curve mapping ``x`` in [0, 1] onto [0, 1].

    The input is rescaled to [-1, 1], divided by ``sqrt(1 - x**2)`` (plus a
    small epsilon so the end points do not divide by zero) and squashed with
    ``tanh``. 0 maps to 0.0, 0.5 to 0.5 and 1 to 1.0. Inputs outside [0, 1]
    make ``math.sqrt`` raise ``ValueError``.
    """
    centered = x * 2. - 1.
    denominator = math.sqrt(1. - math.pow(centered, 2.)) + 1e-6
    squashed = math.tanh(1.5 * centered / denominator)
    return squashed / 2 + .5

def _as_latent_list(latents):
    """Return ``latents`` unchanged if it is a list, else wrap it in one."""
    return latents if isinstance(latents, list) else [latents]


def interpolate(templist, descs, model, audiofile):
    """Generate the in-between frames for every pair of consecutive lyrics.

    For each pair of neighbouring entries in ``descs`` the latents saved in
    ``templist`` are blended with the ``sigmoid`` easing curve, rendered with
    ``model`` and handed to ``create_image``. A line with the lyric text and
    the per-frame duration is appended to
    ``gdrive/MyDrive/biggan-clip/info.txt`` for every frame and once more per
    lyric.

    Reads the module-level globals ``interpol``, ``generator`` and ``lyrics``.

    Args:
        templist: One file object per entry of ``descs``, each readable with
            ``torch.load``. Every file is loaded exactly once.
        descs: Sequence of entries indexed as ``entry[0]`` (timestamp) and
            ``entry[1]`` (lyric text).
        model: Generator used to render the frames.
        audiofile: Not used while the video assembly step is disabled; kept
            so existing callers keep working.

    Raises:
        ValueError: If ``generator`` is not a supported name.
    """
    # NOTE: video assembly (createvid / concatvids) is currently disabled in
    # this notebook, so only frames and info.txt are produced.
    with open("gdrive/MyDrive/biggan-clip/info.txt", "w") as info_file:
        z1 = None
        for idx1, (pt, next_pt) in enumerate(zip(descs, descs[1:])):
            current_lyric = pt[1]

            # interval between the two lines, in timestamp units
            ttime = next_pt[0] - pt[0]

            # The end latents of the previous pair are the start of this one.
            # Both ends are always lists, so zip() below pairs tensor with
            # tensor even for generators that save a single tensor.
            if idx1 == 0:
                z0 = _as_latent_list(torch.load(templist[0]))
            else:
                z0 = z1
            z1 = _as_latent_list(torch.load(templist[idx1 + 1]))

            # number of frames to generate for this pair
            N = round(ttime * interpol)
            print("idx", idx1,"ttime", ttime, "n", N )

            if N < 1:
                # No frames, hence no per-frame duration to log.
                continue
            frame_duration = str(ttime / N)

            # Held for the disabled createvid step.
            image_temp_list = []
            for t in range(N):
                # With a single frame there is nothing to sweep; stay on z0.
                fraction = t / (N - 1) if N > 1 else 0.0
                weight = sigmoid(fraction)
                azs = [start + weight * (end - start) for start, end in zip(z0, z1)]

                # Generate image
                with torch.no_grad():
                    if generator == 'biggan':
                        img = model(azs[0], azs[1], 1).cpu().numpy()
                        img = img[0]
                    elif generator == 'dall-e':
                        img = unmap_pixels(torch.sigmoid(model(azs[0])[:, :3]).cpu().float()).numpy()
                        img = img[0]
                    elif generator == 'stylegan':
                        img = model(azs[0])
                    else:
                        raise ValueError(
                            f"unknown generator {generator!r}; expected 'biggan', 'dall-e' or 'stylegan'")

                    image_temp = create_image(img, t, current_lyric, generator)
                info_file.writelines([f'{current_lyric}', " ", frame_duration, "\n"])
                image_temp_list.append(image_temp)

            info_file.writelines([f'{current_lyric}', " ", frame_duration, "\n"])

    print(descs, "___________________________________________", lyrics )

    

In [None]:
# Run configuration. main() and interpolate() read these names as globals.
# generator: which model to use ('biggan', 'dall-e' or 'stylegan')
generator='biggan'
# epochs: number of train() steps per lyric line
epochs      = 1
# textfile: lyrics file handed to init_textfile()
textfile    = 'gdrive/MyDrive/biggan-clip/test_lyrics.txt'
# audiofile: passed to interpolate()
audiofile   = 'gdrive/MyDrive/biggan-clip/Weasle_sample_audio.mp3'
# interpol: frames per unit of timestamp difference (N = round(ttime * interpol))
interpol    = 2
# lyrics: only printed by interpolate() while the concatvids call is disabled
lyrics      = True
# sideX, sideY: forwarded to train()
sideX       = 512
sideY       = 512
templist, descs, model = main()
interpolate(templist, descs, model, audiofile)
# number of images = seconds * interpol * (lines + 2)?


loading clip model...
model loaded
biggan found
biggan loaded


  0%|          | 0/4 [00:00<?, ?it/s]

training
0


 25%|██▌       | 1/4 [01:14<03:42, 74.27s/it]

training
0


 50%|█████     | 2/4 [02:25<02:25, 72.68s/it]

training
0


 75%|███████▌  | 3/4 [03:37<01:12, 72.26s/it]

training
0


100%|██████████| 4/4 [04:47<00:00, 72.00s/it]


idx 0 ttime 2.0 n 4




idx 1 ttime 2.0 n 4




idx 2 ttime 9.0 n 18




[(-2208988800.0, 'start song'), (-2208988798.0, "I'm curling in my sleeping bag"), (-2208988796.0, 'all dressed up in my shower cap'), (-2208988787.0, 'end song')] ___________________________________________ True


In [None]:
from textwrap3  import fill
import moviepy.editor as me
import tempfile
import textwrap
import glob
import os

def create_vid(path='gdrive/MyDrive/biggan-clip/out', description=' ', fps=24, duration=0.1):
    """Build a captioned video from every file in a directory of frames.

    Args:
        path: Directory whose entries are all loaded as frames, in sorted
            name order.
        description: Caption text; wrapped when longer than 35 characters.
            Given a default because a parameter without one may not follow
            ``path``, which has one.
        fps: Frame rate passed to ``write_videofile``.
        duration: Duration given to every frame clip.

    Returns:
        The temporary-file object whose name was used as the video stem. It
        has already left its ``with`` block when returned; the video itself
        is written to ``<returned.name>.mp4``.
    """
    blackbg = me.ColorClip((720,720), (0, 0, 0))

    # os.listdir() yields bare names in arbitrary order: sort them and join
    # them with the directory so the files can actually be opened.
    # NOTE(review): plain string sort puts "frame10" before "frame2";
    # confirm the frame naming scheme.
    frame_paths = [os.path.join(path, name) for name in sorted(os.listdir(path))]
    clips = [me.ImageClip(frame, duration=duration) for frame in frame_paths]
    concat_clip = me.concatenate_videoclips(clips, method="compose").set_position(('center', 'center'))

    if len(description) > 35:
        description = fill(description, 35)

    # Caption on a semi-opaque black bar spanning the canvas width.
    txtClip = me.TextClip(description, color='white', fontsize=30, font='Amiri-regular').set_position('center')
    txt_col = txtClip.on_color(size=(blackbg.w, txtClip.h + 10),
                               color=(0,0,0), pos=('center', 'center'),
                               col_opacity=0.8)

    txt_mov = txt_col.set_position((0, blackbg.h-20-txtClip.h))
    comp_list = [blackbg, concat_clip, txt_mov]
    final = me.CompositeVideoClip(comp_list).set_duration(concat_clip.duration)

    with tempfile.NamedTemporaryFile() as video_tempfile:
        # The temp file only lends its unique name; the video goes next to it.
        final.write_videofile(video_tempfile.name+".mp4", fps=fps)
        video_tempfile.seek(0)

        for clip in clips:
            clip.close()
        for clip in comp_list:
            clip.close()
        return video_tempfile

In [None]:
# Mount Google Drive at /content/gdrive (only available inside Colab).
from google.colab import drive
drive.mount('/content/gdrive')

# Scratch demo: write a few sample strings to a file on the mounted Drive.
L = ["This is Delhi ","This is Paris ","This is London "]

# The sample strings contain no "\n" and writelines() adds none, so all
# five repetitions end up on a single line.
with open("gdrive/MyDrive/biggan-clip/myfile.txt","w") as file1:
  for i in range(5):
    file1.writelines(L)
    #file1.writelines([f'current_lyrics', " ", str(ttime/N)," ", str(duration), "\n"])
# leaving the `with` block closes the file, even if a write fails


# file1 = open("myfile.txt","r+")

# print("Output of Read function is ")
# print(file1.read())
# print()

# # seek(n) takes the file handle to the nth
# # bite from the beginning.
# file1.seek(0)

# print( "Output of Readline function is ")
# print(file1.readline())
# print()

# file1.seek(0)

# # To show difference between read and readline
# print("Output of Read(9) function is ")
# print(file1.read(9))
# print()

# file1.seek(0)

# print("Output of Readline(9) function is ")
# print(file1.readline(9))

# file1.seek(0)
# # readlines function
# print("Output of Readlines function is ")
# print(file1.readlines())
# print()
# file1.close()


Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).
