# Mini Imagenet Generation

This is one of a pair of notebooks used for generating an ImageNet-like dataset of training
data using stable diffusion models. The difficulty of such artificial datasets can be
easily tuned, and they are useful for debugging and testing deep learning applications.

The first notebook (this one) uses Mistral-7B to take class labels and generate descriptive prompts
for image generation. The prompts are written out as shards to disk and shuffled. The process
is parallelized using Ray.

The second notebook uses Stable Diffusion to take descriptive prompts/image captions
and render them as images. This is a straightforward shard-to-shard transformation.

Note that we are using explicit parallelization over shard files in the initial generation
and the image generation, while we are using ray.data for the actual shuffling. That is
because using explicit parallelization over shards makes it easier to restart jobs that have
failed halfway through for some reason.

In [None]:
import itertools, random, uuid
from pprint import pprint
import os

import torch
import webdataset as wds
from transformers import AutoModelForCausalLM, AutoTokenizer
from webdataset import filters
import textwrap
import time

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from typing import List

def take(n, iterable):
    """Collect the leading ``n`` items of ``iterable`` into a list.

    The iterable is consumed only as far as needed; fewer than ``n`` items
    are returned if it runs out first.
    """
    head = itertools.islice(iterable, n)
    return list(head)


def get_gpu_memories():
    """Return, per visible CUDA device, total memory minus PyTorch-allocated memory (bytes).

    Only memory allocated by PyTorch tensors in this process is subtracted.
    Returns an empty list when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return []
    return [
        torch.cuda.get_device_properties(dev).total_memory - torch.cuda.memory_allocated(dev)
        for dev in range(torch.cuda.device_count())
    ]

def get_num_gpus():
    """Return the total GPU count that Ray reports for the whole cluster.

    The value is ``ray.cluster_resources()["GPU"]`` as-is (it may be a float,
    hence callers convert it with ``int``); a KeyError propagates when the
    cluster advertises no GPU resource.
    """
    total_resources = ray.cluster_resources()
    return total_resources["GPU"]

def ray_get(future, timeout=0.1):
    """Fetch the value of a single Ray object ref without blocking for long.

    Args:
        future: The object ref to resolve.
        timeout: Maximum number of seconds to wait for it to become ready.

    Raises:
        TimeoutError: if ``future`` is not ready within ``timeout`` seconds.
    """
    done, _ = ray.wait([future], timeout=timeout)
    if not done:
        raise TimeoutError()
    return ray.get(future)

def is_ready(actor, timeout=0.1):
    """Return True if the given Ray ref becomes ready within ``timeout`` seconds.

    NOTE(review): despite the parameter name, ``ray.wait`` is documented to take
    object refs (e.g. the result of ``handle.method.remote()``), not actor
    handles; passing a bare actor handle is expected to fail.
    """
    done, _ = ray.wait([actor], timeout=timeout)
    return bool(done)

def select_or_delete(actors, predicate, on_reject=None):
    """Return the actors for which ``predicate(actor)`` is truthy, in order.

    Rejected actors are left out of the result. Dropping them from the list
    does not by itself terminate them (the previous ``del actor`` only unbound
    the loop variable); a Ray actor lives on while any handle to it remains.
    To dispose of rejected actors explicitly, pass ``on_reject`` (for example
    ``ray.kill``), which is called once per rejected actor.

    Args:
        actors: Iterable of actors (or any items) to filter.
        predicate: Callable returning a truthy value for items to keep.
        on_reject: Optional callable invoked with each rejected item.

    Returns:
        List of the kept items, preserving input order.
    """
    kept = []
    for actor in actors:
        if predicate(actor):
            kept.append(actor)
        elif on_reject is not None:
            on_reject(actor)
    return kept

In [None]:
# parameters
#
# Dataset size and layout.
nclasses = 10        # number of classes; must be 10, 100, or 1000
nimages = 100        # captions written per shard (passed to make_shard as n)
ngenerated = 20      # captions sampled per prompt (passed to make_shard as k)
nshards = 1281       # number of training shards
nvalshards = 50      # number of validation shards

# Output locations, both derived from nclasses.
odir = f"./mini-imagenet-{nclasses}"   # output directory
oprefix = f"mi{nclasses}"              # file-name prefix of the final shards

# Ray actor configuration.
nactors = -1              # number of actors to use; -1 means one per GPU in the cluster
check_sufficient = True   # check that each actor has sufficient GPU memory
actor_startup_wait = 10   # seconds to wait for actors to start up (NOTE(review): not referenced in the visible cells)

In [None]:
!echo "odir=$odir"
!mkdir -p $odir

In [None]:
# Select the class names for the requested dataset size.
if nclasses == 10:
    imagenet_classes = "dog cat car plane bird fish frog horse sheep truck".split()
elif nclasses == 100:
    # Deduplicated and sorted so that class indices are stable across runs.
    imagenet_classes = sorted(set("""
    3d_printer aircraft_carrier airplane apple backpack banana baseball_bat
    baseball_glove bat bear bed bench bird book bottle bowl broccoli cake camel car
    carrot cat cell_phone chair clock cloud couch cup dining_table dog donut
    elephant fire fish fork fox frisbee frog giraffe hair_drier handbag horse
    hot_dog hydrant kangaroo keyboard kite knife lamp laptop lion meteor microwave
    monitor monkey mouse mushroom octopus orange oven palm_tree panda parking_meter
    pear pizza plane potted_plant refrigerator remote rocket sandwich scissors sheep
    sink skateboard skis snowboard spoon sports_ball stop_sign street_sign suitcase
    surfboard sweet_pepper table teddy_bear telephone tennis_racket tie tiger
    toaster toilet toothbrush tree truck tv umbrella vase wine_glass zebra
    """.split()))
elif nclasses == 1000:
    # The file is split on whitespace, so each class name must be a single token
    # (e.g. words joined by underscores). `with` closes the file handle.
    with open("imagenet1000.txt", encoding="utf-8") as class_file:
        imagenet_classes = class_file.read().split()
else:
    raise ValueError(f"invalid number of classes: {nclasses}, must be 10, 100, or 1000")

# Explicit check (not `assert`, which is skipped under `python -O`).
if len(imagenet_classes) != nclasses:
    raise ValueError(f"expected {nclasses} class names, got {len(imagenet_classes)}")

# Generation Classes

We encapsulate the model and the generation logic in a low-level class (`TextGenerationModel`) and a high-level class (`CaptionGenerator`). We can then instantiate these classes once per GPU and call them to generate the shards.

In [None]:
class TextGenerationModel:
    """Low-level wrapper: a causal language model plus tokenizer for batched sampling on CUDA."""

    def __init__(
        self,
        model_name: str = "mistralai/Mistral-7B-Instruct-v0.1",
        temperature: float = 2.0,
        top_p: float = 0.9,
        top_k: int = 10,
        max_length: int = 96,
        num_return_sequences: int = 10,
    ):
        """
        Load the tokenizer and model, and store the sampling parameters.

        Args:
            model_name: Name (or path) of the pretrained model to load.
            temperature: Sampling temperature; higher values give more diverse output.
            top_p: Nucleus-sampling threshold: sample only from the smallest set of
                most probable tokens whose cumulative probability reaches top_p.
            top_k: Number of highest-probability tokens kept for top-k filtering.
            max_length: Used both to truncate the tokenized prompts and as the total
                sequence length (prompt + continuation) passed to ``generate``.
            num_return_sequences: Number of sampled continuations per input text.
        """
        # Left padding so that, in a padded batch, every prompt ends right where
        # generation begins.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            padding_side="left",
        )
        # Reuse EOS as the pad token so that batched inputs can be padded
        # (padding=True in generate_responses).
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # Load the weights directly in fp16, then move them to the GPU. The previous
        # `.to("cuda").half()` copied the fp32 weights to the GPU first and converted
        # afterwards, so peak GPU memory was that of the full fp32 model.
        self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
        self.model.to("cuda")

        # Sampling parameters used by generate_responses.
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        self.max_length = max_length
        self.num_return_sequences = num_return_sequences

    def generate_responses(self, texts: List[str]) -> List[str]:
        """
        Sample continuations for a batch of prompts.

        Args:
            texts: Prompts to generate responses for.

        Returns:
            ``len(texts) * num_return_sequences`` decoded strings. The samples for
            each prompt are contiguous, and each string still contains the decoded
            prompt text (callers strip it if needed).
        """
        inputs = self.tokenizer(
            texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_length,
        ).to("cuda")

        with torch.no_grad():
            outputs = self.model.generate(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                do_sample=True,
                temperature=self.temperature,
                top_p=self.top_p,
                top_k=self.top_k,
                max_length=self.max_length,
                num_return_sequences=self.num_return_sequences,
            )

        return [self.tokenizer.decode(output, skip_special_tokens=True) for output in outputs]

In [None]:
from typing import Dict, List, Iterator

class CaptionGenerator:
    """High-level generator: builds per-class prompts, samples captions, and writes shards."""

    def __init__(self):
        # Instruction-formatted prompt; the model's answer follows the "[/INST]" marker.
        self.template = "[INST] Generate a random, detailed visual caption/description of a photo showing: {object}. [/INST]"

    def load_model(self):
        """Load the text generation model (separate from __init__ so the GPU memory check can run first)."""
        self.model = TextGenerationModel()

    def gpu_is_sufficient(self):
        """Return True if this process's single GPU has more than 32 GB not allocated by PyTorch.

        Raises:
            RuntimeError: if the process does not see exactly one GPU.
        """
        gpu_memories = get_gpu_memories()
        if len(gpu_memories) != 1:
            raise RuntimeError(f"expected exactly one GPU per actor, found {len(gpu_memories)}")
        return gpu_memories[0] / 1e9 > 32.0

    def process_batch(self, batch: List[Dict], trim: bool = True) -> List[Dict]:
        """Attach generated responses to every sample of ``batch``.

        Each sample needs a "text" key holding its prompt. Samples are mutated in
        place (a "responses" list of ``num_return_sequences`` strings is added) and
        the same list is returned. With ``trim``, only the text after the last
        "[/INST]" marker is kept, which strips the echoed prompt.
        """
        texts = [sample["text"] for sample in batch]
        responses = self.model.generate_responses(texts)
        if trim:
            responses = [response.split("[/INST]")[-1].strip() for response in responses]
        # The k samples for each prompt are contiguous in the output.
        k = self.model.num_return_sequences
        groups = [responses[i : i + k] for i in range(0, len(responses), k)]
        for i, sample in enumerate(batch):
            sample["responses"] = groups[i]
        return batch

    def process_list_by_batches(self, samples: List[Dict], batch_size: int = 1) -> Iterator[Dict]:
        """Lazily process ``samples`` in batches of ``batch_size``, yielding processed samples."""
        samples_iter = iter(samples)
        while True:
            batch = take(batch_size, samples_iter)
            if not batch:
                break
            yield from self.process_batch(batch)

    def make_samples(self, n: int) -> Iterator[Dict]:
        """Yield ``n`` prompt samples, each for a class drawn uniformly from ``imagenet_classes``."""
        for _ in range(n):
            cls = random.randrange(len(imagenet_classes))
            obj = imagenet_classes[cls]
            yield dict(cls=cls, object=obj, text=self.template.format(object=obj))

    def make_captions(self, samples: List[Dict]) -> Iterator[Dict]:
        """Yield one record per generated response, flattening each sample's "responses"."""
        for sample in self.process_list_by_batches(samples):
            for response in sample["responses"]:
                yield dict(
                    cls=sample["cls"],
                    object=sample["object"],
                    text=sample["text"],
                    response=response,
                )

    def make_shard(self, output: str, n: int, k: int = 5):
        """
        Generate a shard of samples with generated captions.

        The shard is written to ``output + ".temp"`` and renamed to ``output`` only
        once complete, so an existing ``output`` is always a finished shard; such
        shards are skipped, which makes interrupted jobs restartable.

        Args:
            output: The output file to write the shard to.
            n: The number of captions to write to the shard.
            k: The number of return sequences sampled per prompt.
        """
        if os.path.exists(output):
            return
        self.model.num_return_sequences = k
        temp_output = output + ".temp"
        # n // k + k prompts give at least n captions (k each); the surplus is dropped.
        captions = self.make_captions(self.make_samples(n // k + k))
        # Close the writer before renaming: closing flushes buffered data and writes
        # the tar end-of-archive marker. Previously the writer was never closed.
        with wds.TarWriter(temp_output) as writer:
            for caption in itertools.islice(captions, n):
                writer.write(
                    dict(
                        __key__=uuid.uuid4().hex,
                        json=caption,
                    )
                )
        os.rename(temp_output, output)


# Parallelization with Ray

For parallel generation, we use a Ray cluster. This also works on a single machine with a single GPU, and it automatically scales up to the GPUs available in the cluster.

In [None]:

import ray

# Call ray.init() (default arguments) only if Ray is not yet initialized in this
# process; calling ray.init() a second time would raise, so this guard makes
# re-running the cell harmless.
if not ray.is_initialized():
    ray.init()


In [None]:
@ray.remote(num_gpus=1)
class RayCaptionGenerator(CaptionGenerator):
    """Ray actor version of CaptionGenerator; each instance reserves one GPU.

    All behavior, including the constructor, is inherited unchanged.
    """

In [None]:
# Start up and create the actor pool.
# This tries to adapt to the number of GPUs available.
# It also checks that each actor has sufficient memory.
# If not, set up your cluster differently by excluding GPUs that are too small.
# (Ray's facilities for heterogenous clusters are somewhat limited)

ngpus = get_num_gpus() if nactors == -1 else nactors

print(f"using {ngpus} actors")
actors = [RayCaptionGenerator.remote() for _ in range(int(ngpus))]
if not actors:
    raise RuntimeError("no actors created (actor count is 0); an empty ActorPool cannot process shards")

# Check every actor *before* loading any model: the check measures memory not yet
# allocated by PyTorch. All checks run concurrently. Honors `check_sufficient`,
# which was previously ignored.
if check_sufficient:
    print("checking GPU memory")
    sufficient = ray.get([a.gpu_is_sufficient.remote() for a in actors])
    if not all(sufficient):
        bad = [idx for idx, ok in enumerate(sufficient) if not ok]
        raise RuntimeError(f"GPU memory insufficient on actor(s) {bad}")

# Submit all loads, then wait once, so the models load in parallel rather than
# one actor at a time.
print("loading the models")
ray.get([a.load_model.remote() for a in actors])

print("creating the pool")
pool = ray.util.ActorPool(actors)

In [None]:
# It would be nice if there were a .map_with_actors method in pool,
# but there isn't, so we use this workaround.

def apply_actor(actor, dest):
    """Submit one shard-generation task for ``dest`` to ``actor``; return its object ref.

    Uses the notebook-level ``nimages`` (captions per shard) and ``ngenerated``
    (responses sampled per prompt).
    """
    submit = actor.make_shard.remote
    return submit(dest, nimages, ngenerated)

!mkdir -p $odir/prompts

In [None]:
# Perform the actual shard generation.

dests = [f"{odir}/prompts/prompts-{i:06d}.tar" for i in range(nshards + nvalshards)]
# make_shard returns None, so result order does not matter. map_unordered hands an
# actor its next shard as soon as it finishes; the ordered map() only returns an
# actor to the pool when its result is next in submission order, so one slow
# shard could leave other GPUs idle.
result = list(pool.map_unordered(apply_actor, dests))

In [None]:
# Kill the generation actors explicitly so their GPUs (and loaded models) are
# released now. Deleting the names alone frees an actor only once no other
# handle to it remains, and a leftover loop variable can keep one alive.
for _actor in actors:
    ray.kill(_actor)
del actors
del pool

# Shuffle

For shuffling the dataset, we use the ray.data `read_webdataset` and `write_webdataset` functions.

In [None]:
import ray
from ray.data import read_webdataset
import glob

In [None]:
!mkdir -p $odir/shuffled
!rm -f $odir/shuffled/*
# Completed prompt shards only: unfinished "*.tar.temp" files do not match.
shards = glob.glob(f"{odir}/prompts/prompts-*.tar")
dataset = read_webdataset(shards)
# Globally shuffle all samples, then write as many output shards as were read.
shuffled_dataset = dataset.random_shuffle()
(
    shuffled_dataset
    .repartition(len(shards))
    .write_webdataset(f"{odir}/shuffled/")
)

In [None]:
# The output of write_webdataset is a directory of shards, but not following
# the usual naming conventions. We rename the shards to follow typical
# webdataset conventions.

import glob
import os

shuffled = sorted(glob.glob(f"{odir}/shuffled/*.tar"))
# Exactly nshards + nvalshards files are required: with fewer, the renames would
# fail halfway (leaving a half-renamed directory); with more, the extras would be
# silently left behind under Ray's names.
expected_count = nshards + nvalshards
if len(shuffled) != expected_count:
    raise RuntimeError(
        f"found {len(shuffled)} shuffled shards in {odir}/shuffled, expected {expected_count}; "
        "check that all prompt shards were generated"
    )
# The first nshards files become training shards, the rest validation shards.
for i, src in enumerate(shuffled[:nshards]):
    os.rename(src, f"{odir}/shuffled/{oprefix}-{i:06d}.tar")
for i, src in enumerate(shuffled[nshards:]):
    os.rename(src, f"{odir}/shuffled/{oprefix}-val-{i:06d}.tar")
